Dissociative and prioritized modeling of behaviorally relevant neural dynamics using recurrent neural networks

Model formulation
Equation (1) simplifies the DPAD model by showing both of its RNN sections as one, but the general two-section form of the model is as follows:$$\left\{\begin{array}{c}\left[\begin{array}{c}{x}_{k+1}^{\left(1\right)}\\ {x}_{k+1}^{\left(2\right)}\end{array}\right]=\left[\begin{array}{c}{{A}^{{\prime} }}^{\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)\\ {{A}^{{\prime} }}^{\left(2\right)}\left({x}_{k}^{\left(2\right)}\right)\end{array}\right]+\left[\begin{array}{c}{K}^{\,\left(1\right)}\left(\;{y}_{k}\right)\\ {K}^{\,\left(2\right)}\left({y}_{k},{x}_{k+1}^{\left(1\right)}\right)\end{array}\right]\\ {y}_{k}={C}_{y}^{\,\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)+{C}_{y}^{\,\left(2\right)}\left({x}_{k}^{\left(2\right)}\right)+{e}_{k}\\ {z}_{k}={C}_{z}^{\,\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)+{C}_{z}^{\,\left(2\right)}\left({x}_{k}^{\left(2\right)}\right)+{\epsilon }_{k}\end{array}.\right.$$
(2)
This equation separates the latent states of Eq. (1) into the following two parts: \({x}_{k}^{\left(1\right)}\in {{\mathbb{R}}}^{{n}_{1}}\) denotes the latent states of the first RNN section that summarize the behaviorally relevant dynamics, and \({x}_{k}^{\left(2\right)}\in {{\mathbb{R}}}^{{n}_{2}}\), with \({n}_{2}={n}_{x}-{n}_{1}\), denotes those of the second RNN section that represent the other neural dynamics (Supplementary Fig. 1a). Here, A′(1), A′(2), K(1), K(2), \({C}_{y}^{\,\left(1\right)}\), \({C}_{y}^{\,\left(2\right)}\), \({C}_{z}^{\,\left(1\right)}\) and \({C}_{z}^{\,\left(2\right)}\) are multi-input–multi-output functions that parameterize the model, which we learn using a four-step numerical optimization formulation expanded on in the next section (Supplementary Fig. 1a). DPAD also supports learning the initial value of the latent states at time 0 (that is, \({x}_{0}^{\left(1\right)}\) and \({x}_{0}^{\left(2\right)}\)) as a parameter, but in all analyses in this paper, the initial states are simply set to 0 given their minimal impact when modeling long data sequences. Each pair of superscripted parameters (for example, A′(1) and A′(2)) in Eq. (2) is a dissociated version of the corresponding nonsuperscripted parameter in Eq. (1) (for example, A′). The computation graph for Eq. (2) is provided in Fig. 1b (and Supplementary Fig. 1a). In Eq. (2), the recursions for computing \({x}_{k}^{\left(1\right)}\) are not dependent on \({x}_{k}^{\left(2\right)}\), thus allowing the former to be computed without the latter. By contrast, \({x}_{k}^{\left(2\right)}\) can depend on \({x}_{k}^{\left(1\right)}\), and this dependence is modeled via K(2) (see Supplementary Note 2). Note that such dependence of \({x}_{k}^{\left(2\right)}\) on \({x}_{k}^{\left(1\right)}\) via K(2) does not introduce new dynamics to \({x}_{k}^{\left(2\right)}\) because it does not involve the recursion parameter A′(2), which describes the dynamics of \({x}_{k}^{\left(2\right)}\). 
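To make the dependence structure of Eq. (2) concrete, here is a minimal NumPy sketch of the two-section recursion. It assumes each parameter is supplied as an arbitrary Python callable (linear maps in the usage test, but any function works) and that the initial states are 0, as in the analyses described above. The function name `run_two_section_rnn` and its argument names are illustrative and are not part of DPAD's code library. Note that the loop updates \({x}_{k}^{(1)}\) without ever reading \({x}_{k}^{(2)}\), whereas K(2) receives the freshly updated \({x}_{k+1}^{(1)}\).

```python
import numpy as np


def run_two_section_rnn(y, A1, K1, A2, K2, Cy1, Cy2, Cz1, Cz2, n1, n2):
    """Filter neural data y (T x ny) through the two-section model of Eq. (2).

    A1, K1, A2, K2, Cy1, Cy2, Cz1, Cz2 are arbitrary callables standing in
    for the (linear or nonlinear) model parameters. Initial states are 0.
    Returns x1 (T x n1), x2 (T x n2) and the one-step-ahead predictions
    y_hat (T x ny) and z_hat (T x nz).
    """
    T = y.shape[0]
    x1 = np.zeros((T + 1, n1))
    x2 = np.zeros((T + 1, n2))
    for k in range(T):
        # Section 1: depends only on its own state and the neural input y_k.
        x1[k + 1] = A1(x1[k]) + K1(y[k])
        # Section 2: K2 also receives the freshly updated section-1 state.
        x2[k + 1] = A2(x2[k]) + K2(y[k], x1[k + 1])
    # Row k of x1/x2 was computed from y[0], ..., y[k-1] only (causal).
    x1, x2 = x1[:T], x2[:T]
    y_hat = np.array([Cy1(a) + Cy2(b) for a, b in zip(x1, x2)])
    z_hat = np.array([Cz1(a) + Cz2(b) for a, b in zip(x1, x2)])
    return x1, x2, y_hat, z_hat
```

Replacing `A2`, `K2` or `Cy2` leaves the returned `x1` unchanged, which is the property that allows the first section to be computed without the second.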
This two-section RNN formulation is mathematically motivated by equivalent representations of a dynamical system model in different bases and by the relation between the predictor and stochastic forms of dynamical systems (Supplementary Notes 1 and 2).
For the RNN formulated in Eq. (1) or (2), neural activity yk constitutes the input, and predictions of neural and behavioral signals are the outputs (Fig. 1b) given by$$\left\{\begin{array}{c}{\hat{y}}_{k}={C}_{y}\left({x}_{k}\right)\\ {\hat{z}}_{k}={C}_{z}\left({x}_{k}\right)\end{array}.\right.$$
(3)
Note that each xk is estimated using only past neural observations (that is, \({y}_{1},\ldots ,{y}_{k-1}\)), so the predictions in Eq. (3) are one-step-ahead predictions of yk and zk from past neural observations (Supplementary Note 1). Once the model parameters are learned, extracting the latent states xk involves iteratively applying the first line of Eq. (2), and predicting behavior or neural activity involves applying Eq. (3) to the extracted xk. As such, by writing the nonlinear model in predictor form67,68 (Supplementary Note 1), we enable causal and computationally efficient prediction.
Learning: four-step numerical optimization approach
Background
Unlike nondynamic models1,34,35,36,69, dynamical models explicitly model temporal evolution in time series data. Recent dynamical models have gone beyond linear or generalized linear dynamical models2,3,4,5,6,7,70,71,72,73,74,75,76,77,78,79,80,81 to incorporate switching linear10,11,12,13, locally linear37 or nonlinear14,15,16,17,18,19,20,21,23,24,26,27,38,61,82,83,84,85,86,87,88,89,90 dynamics, often using deep learning methods25,91,92,93,94. However, these recent nonlinear or switching methods do not aim to localize nonlinearity or to allow for flexible nonlinearity, and they do not enable fully prioritized dissociation of behaviorally relevant neural dynamics because they either do not consider behavior in their learning objective at all14,16,37,38,61,95,96 or incorporate it with a mixed neural–behavioral objective9,18,35,61 (Extended Data Table 1).
In DPAD, we develop a four-step learning method for training our two-section RNN in Eq. (1) and extracting the latent states that (1) enables dissociation and prioritized learning of the behaviorally relevant neural dynamics in the nonlinear model, (2) allows for flexible modeling and localization of nonlinearities, (3) extends to data with diverse distributions and (4) does all this while also achieving causal decoding and being applicable to data both with and without a trial structure.
DPAD is designed for nonlinear modeling, and each step of its multistep learning approach uses numerical optimization tools rooted in deep learning. DPAD is therefore mathematically distinct from our prior PSID work for linear models, which is an analytical and linear technique. PSID is based on analytical linear algebraic projections rooted in control theory6, so it cannot be extended to nonlinear modeling or to non-Gaussian, noncontinuous or intermittently sampled data. Thus, even when we restrict DPAD to linear modeling as a special case, it is still mathematically different from PSID6.
Overview
To dissociate and prioritize the behaviorally relevant neural dynamics, we devise a four-step optimization approach for learning the two-section RNN model parameters (Supplementary Fig. 1a). This approach prioritizes the extraction and learning of the behaviorally relevant dynamics in the first two steps with states \({x}_{k}^{\left(1\right)}\in {{\mathbb{R}}}^{{n}_{1}}\), learns the rest of the neural dynamics in the last two steps with states \({x}_{k}^{\left(2\right)}\in {{\mathbb{R}}}^{{n}_{2}}\) and dissociates the two subtypes of dynamics. This prioritization is important for accurate learning of behaviorally relevant neural dynamics and is achieved through the multistep learning approach: the earlier steps learn the behaviorally relevant dynamics first, that is, with priority, and the subsequent steps then learn the other neural dynamics so that these do not mask or confound the behaviorally relevant dynamics. Importantly, each optimization step is independent of the subsequent steps, so all steps can be performed once, in order, with no need to iteratively repeat any step. We define the neural and behavioral prediction losses used in the optimization steps based on the negative log-likelihoods (NLLs) associated with the neural and behavior distributions, respectively.
This approach benefits from the statistical foundation of maximum likelihood estimation and facilitates generalization across behavioral distributions. We now describe each of the four optimization steps for RNN training.
Optimization step 1
In the first two optimization steps (Supplementary Fig. 1a), the objective is to learn the behaviorally relevant latent states \({x}_{k}^{\left(1\right)}\) and their associated parameters. In the first optimization step, we learn the parameters A′(1), \({C}_{z}^{\,\left(1\right)}\) and K(1) of the RNN$$\left\{\begin{array}{c}{x}_{k+1}^{\left(1\right)}={{A}^{{\prime} }}^{\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)+{K}^{\,\left(1\right)}\left(\;{y}_{k}\right)\\ {z}_{k}={C}_{z}^{\,\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)+{\epsilon }_{k}\end{array}\right.$$
(4)
and estimate its latent state \({x}_{k}^{\left(1\right)}\) while minimizing the NLL of the behavior zk given \({x}_{k}^{\left(1\right)}\). For continuous-valued (Gaussian) behavioral data, we minimize the following sum of squared prediction errors69,97:$${L}_{z}^{(1)}=\sum _{k}{\left\Vert {z}_{k}-{\hat{z}}_{k}\right\Vert }_{2}^{2}=\sum _{k}{\left\Vert {z}_{k}-{C}_{z}^{\,(1)}({x}_{k}^{(1)})\right\Vert }_{2}^{2}$$
(5)
where the sum is over all available samples of behavior zk, and \({\Vert .\Vert }_{2}\) indicates the two-norm operator. This objective, which is typically used when fitting models to continuous-valued data69,97, is, up to an additive constant, proportional to the Gaussian NLL if we assume isotropic Gaussian residuals (that is, Σ𝜖 = σ𝜖I)69,97. If desired, a general nonisotropic residual covariance Σ𝜖 can be computed empirically from the model residuals after the above optimization is solved (see Learning noise statistics), although Σ𝜖 is mainly useful for simulating new data and is not needed when using the learned model for inference. The same points hold for the subsequent optimization steps detailed later: the squared error used for continuous-valued data is, up to an additive constant, proportional to the Gaussian NLL under isotropic Gaussian residuals, and the residual covariance can be computed empirically after the optimization if desired.
Optimization step 2
The second optimization step uses the latent state \({x}_{k}^{\left(1\right)}\) extracted by the RNN and fits the parameter \({C}_{y}^{\left(1\right)}\) in$${y}_{k}={C}_{y}^{\,\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)+{e}_{k}$$
(6)
while minimizing the NLL of the neural activity yk given \({x}_{k}^{(1)}\). For continuous-valued (Gaussian) neural activity yk, we minimize the following sum of squared prediction errors69:$${L}_{y}^{(1)}=\sum _{k}{\left\Vert\, {y}_{k}-\hat{y}_{k}\right\Vert }_{2}^{2}=\sum _{k}{\left\Vert\, {y}_{k}-{C}_{y}^{\,(1)}({x}_{k}^{(1)})\right\Vert }_{2}^{2},$$
(7)
where the sum is over all available samples of yk. Optimization steps 1 and 2 conclude the prioritized extraction and modeling of behaviorally relevant latent states \({x}_{k}^{(1)}\) (Fig. 1b) and the learning of the first section of the RNN model (Supplementary Fig. 1a).
Optimization step 3
In optimization steps 3 and 4 (Supplementary Fig. 1a), the objective is to learn any additional dynamics in neural activity that are not learned in the first two optimization steps, that is, \({x}_{k}^{\left(2\right)}\) and the associated parameters. To do so, in the third optimization step, we learn the parameters A′(2), \({C}_{y}^{\,\left(2\right)}\) and K(2) of the RNN$$\left\{\begin{array}{c}{x}_{k+1}^{\left(2\right)}={{A}^{{\prime} }}^{\left(2\right)}\left({x}_{k}^{\left(2\right)}\right)+{K}^{\,\left(2\right)}\left({y}_{k},{x}_{k+1}^{\left(1\right)}\right)\\ {y}_{k}^{{\prime} }={C}_{y}^{\,\left(2\right)}\left({x}_{k}^{\left(2\right)}\right)+{e}_{k}^{{\prime} }\end{array}\right.$$
(8)
and estimate its latent state \({x}_{k}^{\left(2\right)}\) while minimizing the aggregate NLL of yk given both latent states, that is, by also taking into account the NLL obtained from step 2 via the \({C}_{y}^{\,\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)\) term in Eq. (6). The notations \({y}_{k}^{{\prime} }\) and \({e}_{k}^{{\prime} }\) in the second line of Eq. (8) signify that the RNN of Eq. (8) does not predict yk itself but rather the parts of yk that remain unpredicted after extracting \({x}_{k}^{(1)}\). In the case of continuous-valued (Gaussian) neural activity yk, we minimize the following loss:$${L}_{y}^{(2)}=\sum _{k}{\left\Vert\, {y}_{k}-{C}_{y}^{\,(1)}\left({x}_{k}^{(1)}\right)-{C}_{y}^{\,(2)}\left({x}_{k}^{(2)}\right)\right\Vert }_{2}^{2},$$
(9)
where the sum is over all available samples of yk. Note that in the continuous-valued (Gaussian) case, minimizing this loss is equivalent to minimizing the error in predicting the residual neural activity \({y}_{k}-{C}_{y}^{\,\left(1\right)}\left({x}_{k}^{\left(1\right)}\right)\), which is computed using the parameter \({C}_{y}^{\,\left(1\right)}\) and the states \({x}_{k}^{\left(1\right)}\) previously learned and extracted in steps 1 and 2. Also, the input to the RNN in Eq. (8) includes both yk and the \({x}_{k+1}^{\left(1\right)}\) extracted in optimization step 1. This shows how the optimization steps are linked together to compute the aggregate likelihoods.
Optimization step 4
If we assume that the second set of states \({x}_{k}^{\left(2\right)}\) does not contain any information about behavior, we could stop the modeling here. However, this assumption may not hold if the dimension of the states extracted in the first optimization step (that is, n1) is selected to be so small that some behaviorally relevant neural dynamics are not learned in the first step. To be robust to such selections of n1, we can use a final numerical optimization to determine, based on the data, whether and how \({x}_{k}^{\left(2\right)}\) should affect behavior prediction. Thus, a fourth optimization step uses the latent states extracted in optimization steps 1 and 3 and fits Cz in$${z}_{k}={C}_{z}\left({x}_{k}^{\left(1\right)},{x}_{k}^{\left(2\right)}\right)+{\epsilon }_{k}$$
(10)
while minimizing the negative log-likelihood of behavior given both latent states. In the case of continuous-valued (Gaussian) behavior zk, we minimize the following loss:$${L}_{z}^{(2)}=\sum_{k}{\left\Vert {z}_{k}-\hat{z}_{k}\right\Vert}_{2}^{2}=\sum _{k}{\left\Vert {z}_{k}-{C}_{z}({x}_{k}^{(1)},{x}_{k}^{(2)})\right\Vert }_{2}^{2}.$$
(11)
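The continuous-valued losses of the four steps (Eqs. (5), (7), (9) and (11)) can be written compactly as below. This is a minimal sketch that only evaluates the losses for given latent states and readout callables; it does not train the RNNs, which DPAD does with gradient-based optimization. The helper names `sse` and `four_step_losses` and the optional `z_mask` argument are hypothetical; the mask mirrors the extension to intermittently measured behavior described later, in which the behavior loss is computed only on samples where behavior is measured.

```python
import numpy as np


def sse(target, prediction, mask=None):
    """Sum of squared errors over time (rows) and dimensions (columns).

    If a boolean `mask` with one entry per time step is given, only the
    time steps where it is True contribute to the sum.
    """
    sq = (np.asarray(target, dtype=float) - np.asarray(prediction, dtype=float)) ** 2
    if mask is not None:
        sq = sq[np.asarray(mask, dtype=bool)]
    return float(sq.sum())


def four_step_losses(y, z, x1, x2, Cz1, Cy1, Cy2, Cz, z_mask=None):
    """Evaluate the Gaussian-case losses of the four optimization steps.

    y: neural data (T x ny); z: behavior (T x nz); x1, x2: latent states.
    Readouts act on whole (T x n) arrays; Cz takes both state arrays.
    """
    return {
        "L_z1": sse(z, Cz1(x1), z_mask),         # Eq. (5), step 1
        "L_y1": sse(y, Cy1(x1)),                 # Eq. (7), step 2
        "L_y2": sse(y, Cy1(x1) + Cy2(x2)),       # Eq. (9), step 3
        "L_z2": sse(z, Cz(x1, x2), z_mask),      # Eq. (11), step 4
    }
```

Because Eq. (9) subtracts the fixed step-2 prediction, `L_y2` equals the squared error of `Cy2(x2)` against the residual `y - Cy1(x1)`, which is the equivalence noted in the text.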
The parameter Cz that is learned in this optimization step will replace both \({C}_{z}^{\,\left(1\right)}\) and \({C}_{z}^{\,\left(2\right)}\) in Eq. (2). Optionally, in a final optimization step, a similar nonlinear mapping from \({x}_{k}^{\left(1\right)}\) and \({x}_{k}^{\left(2\right)}\) can also be learned, this time to predict yk, which allows DPAD to support nonlinear interactions of \({x}_{k}^{\left(1\right)}\) and \({x}_{k}^{\left(2\right)}\) in predicting neural activity. In this case, the resulting learned Cy parameter will replace both \({C}_{y}^{\,\left(1\right)}\) and \({C}_{y}^{\,\left(2\right)}\) in Eq. (2). This concludes the learning of both model sections (Supplementary Fig. 1a) and all model parameters in Eq. (2).
In this work, when optimization steps 1 and 3 are both used to extract the latent states (that is, when 0 < n1 < nx), we do not perform the additional fourth optimization step in Eq. (10), and the prediction of behavior is done solely using the \({x}_{k}^{\left(1\right)}\) states extracted in the first optimization step. Note that DPAD can also cover NDM as a special case if we only use the third optimization step to extract the states (that is, n1 = 0, in which case the first two steps are not needed). In this case, we use the fourth optimization step to learn Cz, which is the mapping from the latent states to behavior. Also, in this case, we simply have a unified state xk as there is no dissociation in NDM, and the only goal is to extract states that predict neural activity accurately.
Additional generalizations of state dynamics
Finally, the first lines of Eqs. (4) and (8) can also be written more generally as$${x}_{k+1}^{\left(1\right)}={{A}^{{\prime} {\prime} }}^{\left(1\right)}\left({x}_{k}^{\left(1\right)},{y}_{k}\right)$$
(12)
and$${x}_{k+1}^{\left(2\right)}={{A}^{{\prime} {\prime} }}^{\left(2\right)}\left({x}_{k}^{\left(2\right)},{y}_{k},{x}_{k+1}^{\left(1\right)}\right),$$
(13)
where, instead of an additive relation between the two terms on the right-hand side, both terms are combined in nonlinear functions \({{A}^{{\prime} {\prime} }}^{\left(1\right)}\) and \({{A}^{{\prime} {\prime} }}^{\left(2\right)}\), which, as a special case, can still learn the additive relation in Eqs. (4) and (8). Whenever both the state recursion A′ and neural input K parameters (with the appropriate superscripts) are specified to be nonlinear, we use the more general architecture in Eqs. (12) and (13); if either A′ or K (or both) is linear, we use Eqs. (4) and (8).
As another option, both RNN sections can be made bidirectional, which enables noncausal prediction for DPAD by using future data in addition to past data, with the goal of improving prediction, especially in datasets with stereotypical trials. Although this option is not reported in this work, it is implemented and available for use in DPAD’s public code library.
Learning noise statistics
Once learning is complete, we also compute the covariances of the neural and behavior residual time series ek and 𝜖k as Σe and Σ𝜖, respectively. This makes the learned model in Eq. (1) usable for generating new simulated data. This application is not the focus of this work, but an explanation of it is provided in Numerical simulations.
Regularization
Optional L1-norm or L2-norm regularization of any set of parameters, with the option to automatically select the regularization weight via inner cross-validation, is implemented in the DPAD code. However, we did not use regularization in any of the analyses presented here.
Forecasting
DPAD can also predict neural–behavioral data more than one time step into the future. To obtain a two-step-ahead prediction, we feed the model's one-step-ahead neural predictions back into it as neural observations. This allows us to perform one state update iteration, that is, line 1 of Eq. (2), with yk replaced by \({\hat{y}}_{k}\) from Eq. (3).
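The feedback loop for multistep forecasting can be sketched as follows. For brevity, this sketch uses the merged single-section form of Eq. (1) with generic callables `A`, `K`, `Cy` and `Cz` (hypothetical names standing in for A′, K, Cy and Cz); with two sections, the same substitution applies to line 1 of Eq. (2). Starting from a state \({x}_{k}\) already filtered through past neural data, each iteration feeds the model's own neural prediction back in place of the not-yet-observed neural sample.

```python
import numpy as np


def multi_step_forecast(x_k, A, K, Cy, Cz, m):
    """Forecast 1 to m + 1 steps ahead from a filtered state x_k.

    x_k: latent state computed from neural data up to time k - 1.
    A, K, Cy, Cz: callables for the merged (single-section) model.
    Returns arrays whose row j is the (j + 1)-step-ahead prediction.
    """
    x = np.asarray(x_k, dtype=float)
    y_preds, z_preds = [Cy(x)], [Cz(x)]  # one-step-ahead, as in Eq. (3)
    for _ in range(m):
        # State update with the model's own neural prediction as input.
        x = A(x) + K(y_preds[-1])
        y_preds.append(Cy(x))
        z_preds.append(Cz(x))
    return np.array(y_preds), np.array(z_preds)
```

For a linear model, each feedback iteration multiplies the state by A + K Cy, so row j of the output equals the readouts applied to (A + K Cy)^j x_k.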
Repeating this procedure m times gives the (m + 1)-step-ahead prediction of the latent state and neural–behavioral data.
Extending to intermittently measured behaviors
We also extend DPAD to modeling intermittently measured behavior time series (Extended Data Figs. 8 and 9 and Supplementary Fig. 8). To do so, when forming the behavior loss (Eqs. (5) and (11)), we compute the loss only on samples where the behavior is measured and solve the optimization with this loss.
Extending to noncontinuous-valued data observations
We can also extend DPAD to noncontinuous-valued (non-Gaussian) observations by devising modified loss functions and observation models. Here, we demonstrate this extension for categorical behavioral observations, for example, discrete choices or epochs/phases during a task (Fig. 7). A similar approach could be used in the future to model other non-Gaussian behaviors and non-Gaussian (for example, Poisson) neural modalities, as shown in a thesis56.
To model categorical behaviors, we devise a new behavior observation model for DPAD by making three changes. First, we change the behavior loss (Eqs. (5) and (11)) to the NLL of a categorical distribution, which we implement using the dedicated class in the TensorFlow library (that is, tf.keras.losses.CategoricalCrossentropy). Second, we change the behavior readout parameter Cz to have an output dimension of nz × nc instead of nz, where nc denotes the number of behavior categories or classes. Third, we apply Softmax normalization (Eq. (14)) to the output of Cz so that, for each of the nz behavior dimensions, the predicted probabilities of the nc classes sum to 1 and thus form a valid probability mass function. Softmax normalization can be written as$${p}_{k}^{\left(m,n\right)}=\frac{\exp \left({l}_{k}^{\,\left(m,n\right)}\right)}{{\sum }_{i=1}^{{n}_{c}}\exp \left({l}_{k}^{\,\left(m,i\right)}\right)},$$
(14)
where \({l}_{k}\in {{\mathbb{R}}}^{{n}_{z}\times {n}_{c}}\) is the output of Cz at time k, and the superscript (m,n) denotes the element of lk associated with behavior dimension m and class/category number n. With these changes, we obtain a new RNN architecture with categorical behavioral outputs. We then learn this new RNN architecture with DPAD’s four-step prioritized optimization approach as before, but now incorporating the modified NLL losses for categorical data. Together, these changes extend DPAD to modeling categorical behavioral measurements.
Behavior decoding and neural self-prediction metrics and performance frontier
Cross-validation
To evaluate the learning, we perform a cross-validation with five folds (unless otherwise noted). We cut the data from the recording session into five equal continuous segments, leave these segments out one by one as the test data and train the model only using the data in the remaining segments. Once the model is trained using the neural and behavior training data, we pass the neural test data to the model to get the latent states in the test data using the first line of Eq. (1) (or, equivalently, Eq. (2)). We then pass the extracted latent states to Eq. (3) to get the one-step-ahead prediction of the behavior and neural test data, which we refer to as behavior decoding and neural self-prediction, respectively. Note that only past neural data are used to get the behavior and neural predictions, and the behavior test data are never used in predictions. Given the predicted behavior and neural time series, we compute the correlation coefficient (CC) between each dimension of these time series and the corresponding dimension of the actual behavior and neural test time series.
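A minimal sketch of the fold construction and the per-dimension CC is shown below, assuming samples are stored as rows of (T × n) arrays and that the CC is Pearson's correlation; the function names are illustrative. Each test fold is one contiguous block, so training and test samples are not interleaved in time.

```python
import numpy as np


def contiguous_folds(n_samples, n_folds=5):
    """Yield (train_idx, test_idx); each test fold is one contiguous block."""
    edges = np.linspace(0, n_samples, n_folds + 1).astype(int)
    idx = np.arange(n_samples)
    for lo, hi in zip(edges[:-1], edges[1:]):
        yield np.concatenate([idx[:lo], idx[hi:]]), idx[lo:hi]


def per_dimension_cc(actual, predicted):
    """Pearson CC between matching columns of two (T x n) arrays."""
    return np.array([
        np.corrcoef(actual[:, i], predicted[:, i])[0, 1]
        for i in range(actual.shape[1])
    ])
```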
We then take the mean of CC across dimensions of behavior and neural data to get one final cross-validated CC value for behavior decoding and one final CC value for neural self-prediction in each cross-validation fold.
Selection of the latent state dimension
We often need to select a latent state dimension to report an overall behavior decoding and/or neural self-prediction accuracy for each model/method (for example, Figs. 2–7). By latent state dimension, we always refer to the total latent state dimension of the model, that is, nx. For DPAD, unless otherwise noted, we always used n1 = 16 to extract the first 16 latent state dimensions (or all latent state dimensions when nx ≤ 16) using steps 1 and 2 and any remaining dimensions using steps 3 and 4. We chose n1 = 16 because dedicating more, or even all, latent state dimensions to behavior prediction improved it only minimally across datasets and neural modalities. For all methods, to select a state dimension nx, in each cross-validation fold, we fit models with latent state dimensions 1, 2, 4, 8, 16, …, 128 (powers of 2 from 1 to 128) and select one of these models based on their decoding and neural self-prediction accuracies within the training data of that fold. We then report the decoding/self-prediction of this selected model computed in the test data of that fold. Our goal is often to select a model that simultaneously explains behavior and neural data well. For this goal, we pick the state dimension that reaches the peak neural self-prediction in the training data or the state dimension that reaches the peak behavior decoding in the training data, whichever is larger; we then report both the neural self-prediction and the corresponding behavior decoding accuracy of the same model with the selected state dimension in the test data (Figs. 3–4, 6 and 7f, Extended Data Figs. 3 and 4 and Supplementary Figs. 4–7 and 9).
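The dimension-selection rule above can be sketched as follows, assuming the training-data accuracies for each candidate state dimension are already available as dictionaries keyed by dimension. The function names, the `goal` argument and the `tol` tolerance (0 by default, so that "reaching the peak" means attaining the maximum) are hypothetical choices for illustration. The `goal="decoding"` option corresponds to the behavior-only alternative described in the next sentence.

```python
CANDIDATE_DIMS = [2 ** p for p in range(8)]  # 1, 2, 4, ..., 128


def smallest_dim_at_peak(scores, tol=0.0):
    """Smallest state dimension whose score is within `tol` of the maximum."""
    best = max(scores.values())
    return min(d for d, s in scores.items() if s >= best - tol)


def select_state_dim(train_decoding, train_self_pred, goal="joint", tol=0.0):
    """Select n_x from training-data accuracies keyed by state dimension.

    goal="joint": the larger of the dimension at peak neural self-prediction
    and the dimension at peak behavior decoding.
    goal="decoding": the smallest dimension at peak behavior decoding.
    """
    dec_dim = smallest_dim_at_peak(train_decoding, tol)
    if goal == "decoding":
        return dec_dim
    if goal == "joint":
        return max(dec_dim, smallest_dim_at_peak(train_self_pred, tol))
    raise ValueError(f"unknown goal: {goal}")
```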
Alternatively, for all methods, when our goal is to find models that solely aim to optimize behavior prediction, we report the cross-validated prediction performances for the smallest state dimension that reaches peak behavior decoding in training data (Figs. 2, 5 and 7d, Extended Data Fig. 8 and Supplementary Fig. 3). We emphasize that in all cases, the reported performances are always computed in the test data of the cross-validation fold, which is not used for any other purpose such as model fitting or selection of the state dimension.
Performance frontier
When comparing a group of alternative models, we use the term ‘performance frontier’ to describe the best performances reached by models that, in every comparison with any alternative model, are in some sense better than or at least comparable to that alternative model. More precisely, when comparing a group \({\mathcal{M}}\) of models, model \({\mathcal{A}}\in {\mathcal{M}}\) is described as reaching the best performance frontier if, when compared to every other model \({\mathcal{B}}\in {\mathcal{M}}\), \({\mathcal{A}}\) is significantly better than \({\mathcal{B}}\) in behavior decoding or in neural self-prediction or is comparable to \({\mathcal{B}}\) in both. Note that \({\mathcal{A}}\) may be better than some model \({{\mathcal{B}}}_{1}\in {\mathcal{M}}\) in decoding while being better than another model \({{\mathcal{B}}}_{2}\in {\mathcal{M}}\) in self-prediction; nevertheless, \({\mathcal{A}}\) will be on the frontier as long as, in every comparison, one of the following conditions holds: (1) there is at least one measure in which \({\mathcal{A}}\) is more performant or (2) \({\mathcal{A}}\) is at least equally performant in both measures.
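The frontier definition can be checked mechanically once pairwise significance decisions are available. In the minimal sketch below, `better(a, b, metric)` is a hypothetical user-supplied predicate that returns True when model `a` is significantly better than model `b` on `metric`; in these analyses that decision would come from the significance rule stated in the next paragraph, which the sketch does not implement. A model passes a comparison if it wins on at least one metric or if the two models are comparable in both metrics.

```python
METRICS = ("decoding", "self_prediction")


def on_frontier(a, models, better):
    """Return True if model `a` reaches the best performance frontier.

    better(x, y, metric) must return True when x is significantly better
    than y on `metric` (a user-supplied significance decision).
    """
    for b in models:
        if b == a:
            continue
        a_wins = any(better(a, b, m) for m in METRICS)
        b_wins = any(better(b, a, m) for m in METRICS)
        # Fail only if `b` wins somewhere while `a` wins nowhere.
        if b_wins and not a_wins:
            return False
    return True
```

For instance, a model that is better in decoding but worse in self-prediction than a rival still passes that comparison, which is why several models can share the frontier.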
To avoid excluding models from the best performance frontier because of very small performance differences, in this analysis we declare a difference in performance significant only if, in addition to P ≤ 0.05 in a one-sided signed-rank test, there is also at least a 1% relative difference in the mean performance measures.
DPAD with flexible nonlinearity: automatic determination of appropriate nonlinearity
Fine-grained control over nonlinearities
Each parameter in the DPAD model represents an operation in the computation graph of DPAD (Fig. 1b and Supplementary Fig. 1a). We solve the numerical optimizations involved in each step of our multistep learning via standard stochastic gradient descent43, which remains applicable to any modification of the computation graph that keeps it acyclic. Thus, the operation associated with each model parameter (for example, A′, K, Cy and Cz) can be replaced with any multilayer neural network with an arbitrary number of hidden units and layers (Supplementary Fig. 1c), and the model remains trainable with the same approach. Having no hidden layers implements the special case of a linear mapping (Supplementary Fig. 1b). Of course, because the training data are finite, the typical trade-off between model capacity and generalization error remains69. Given that neural networks can approximate any continuous function on a compact domain98, replacing model parameters with neural networks should provide the capacity to learn any nonlinear function in their place99,100,101. The resulting RNN in Eq. (1) can in turn approximate any state-space dynamics (under mild conditions)102. In this work, for nonlinear parameters, we use multilayer feed-forward networks with one or two hidden layers, each with 64 or 128 units. For all hidden layers, we always use a rectified linear unit (ReLU) nonlinear activation (Supplementary Fig. 1c).
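As an illustration of how a single parameter can be switched between a linear map and a ReLU multilayer network, the NumPy sketch below builds such a callable. It is not DPAD's custom TensorFlow RNN cell: the weights are randomly initialized rather than learned, every layer has a bias that starts at zero, the output layer is linear, and the function name `make_parameter` is hypothetical. Passing `hidden=()` gives the linear special case, while `hidden=(64,)` or `hidden=(128, 128)` gives the one- or two-hidden-layer structures mentioned above.

```python
import numpy as np


def make_parameter(n_in, n_out, hidden=(), rng=None):
    """Return a callable R^n_in -> R^n_out for one model parameter.

    hidden=() gives a linear map; e.g. hidden=(64,) or (128, 128) gives a
    multilayer network with ReLU activations on every hidden layer and a
    linear output layer. Weights are random here, not learned.
    """
    rng = np.random.default_rng() if rng is None else rng
    sizes = [n_in, *hidden, n_out]
    weights = [rng.normal(scale=1.0 / np.sqrt(a), size=(b, a))
               for a, b in zip(sizes[:-1], sizes[1:])]
    biases = [np.zeros(b) for b in sizes[1:]]  # included, but start at zero

    def apply(x):
        h = np.asarray(x, dtype=float)
        for i, (W, b) in enumerate(zip(weights, biases)):
            h = W @ h + b
            if i < len(weights) - 1:  # ReLU on hidden layers only
                h = np.maximum(h, 0.0)
        return h

    apply.weights = weights
    return apply
```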
Finally, when making a parameter (for example, Cz) nonlinear, we always do so for that parameter in both sections of the RNN (for example, both \({C}_{z}^{\,\left(1\right)}\) and \({C}_{z}^{\,\left(2\right)}\); see Supplementary Fig. 1a) using the same feed-forward network structure. Given that no existing RNN implementation allowed individual RNN elements to be independently set to arbitrary multilayer neural networks, we developed a custom TensorFlow RNN cell to implement the RNNs in DPAD (Eqs. (4) and (8)). We used the Adam optimizer to implement gradient descent for all optimization steps43. We continued each optimization for up to 2,500 epochs but stopped earlier if the objective function did not improve in three consecutive epochs (convergence criterion).
Automatic selection of nonlinearity settings
We devise a procedure for automatically determining the most suitable combination of nonlinearities for the data, which we refer to as DPAD with flexible nonlinearity. In this procedure, for each cross-validation fold in each recording session of each dataset, we try a series of nonlinearities within the training data and select one based on an inner cross-validation within the training data (Fig. 1d). Specifically, we consider the following options for the nonlinearity. First, each of the four main parameters (that is, A′, K, Cy and Cz) can be linear or nonlinear, resulting in 16 cases (that is, \({2}^{4}\)). In cases with nonlinearity, we consider four network structures for the parameters, that is, having one or two hidden layers and having 64 or 128 units in each hidden layer (Supplementary Fig. 1c), resulting in 61 cases overall (that is, 15 × 4 + 1, where 1 is for the fully linear model).
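The candidate count can be verified by enumeration. This sketch assumes that the four hidden-layer structures are (64,), (128,), (64, 64) and (128, 128) and that all nonlinear parameters within one candidate share the same structure; the dictionary keys and the function name are illustrative. Setting `include_lstm=True` also adds the LSTM-recursion cases described in the next sentence.

```python
from itertools import product

PARAMS = ("A'", "K", "Cy", "Cz")
# One or two hidden layers with 64 or 128 units each (assumed shared widths).
STRUCTURES = ((64,), (128,), (64, 64), (128, 128))


def enumerate_candidates(include_lstm=False):
    """List candidate nonlinearity settings as small dictionaries."""
    cands = []
    for flags in product((False, True), repeat=4):
        nonlinear = tuple(p for p, f in zip(PARAMS, flags) if f)
        for h in (STRUCTURES if nonlinear else ((),)):
            cands.append({"lstm": False, "nonlinear": nonlinear, "hidden": h})
    if include_lstm:
        # A' is an LSTM; the other three parameters vary as before.
        for flags in product((False, True), repeat=3):
            nonlinear = tuple(p for p, f in zip(PARAMS[1:], flags) if f)
            for h in (STRUCTURES if nonlinear else ((),)):
                cands.append({"lstm": True, "nonlinear": nonlinear, "hidden": h})
    return cands
```

Without the LSTM option this yields 1 + 15 × 4 = 61 candidates; the LSTM option adds 1 + 7 × 4 = 29, for 90 in total.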
Finally, specifically for the recursion parameter A′, we also consider modeling it as an LSTM, with the other parameters still having the same nonlinearity options as before, resulting in another 29 cases in which this LSTM recursion is used (that is, 7 × 4 + 1, where 1 is for the case where the other three model parameters are all linear), bringing the total number of considered cases to 90. For each of these 90 linear or nonlinear architectures, we perform a twofold inner cross-validation within the training data to estimate the behavior decoding and neural self-prediction of each architecture. Although this process for automatically selecting nonlinearities is computationally expensive, it is parallelizable because each candidate model can be fitted independently on a different processor. Once all candidate architectures are fitted and evaluated within the training data, we select one final architecture for that cross-validation fold, purely based on training data, using one of the following two criteria: (1) decoding focused: pick the architecture with the best neural self-prediction in training data among all those that reach within 1 s.e.m. of the best behavior decoding; or (2) self-prediction focused: pick the architecture with the best behavior decoding in training data among all those that reach within 1 s.e.m. of the best neural self-prediction. The first criterion prioritizes good behavior decoding in the selection, and the second prioritizes good neural self-prediction. Note that these two criteria are used when selecting among different already-learned models with different nonlinearities and thus are independent of the four internal objective functions used to learn the parameters of a given model with the four-step optimization approach (Supplementary Fig. 1a).
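The two selection criteria can be sketched as a one-standard-error rule. The sketch assumes that each candidate's inner cross-validation yields an array of at least two scores (for example, across inner folds or data dimensions) from which a mean and s.e.m. are computed, and that the threshold is the best mean minus the s.e.m. of the best candidate; both are assumptions, as is the function name `select_architecture`.

```python
import numpy as np


def _mean_sem(scores):
    """Mean and standard error of the mean (needs at least two scores)."""
    s = np.asarray(scores, dtype=float)
    return s.mean(), s.std(ddof=1) / np.sqrt(s.size)


def select_architecture(decoding, self_pred, focus="decoding"):
    """Pick a candidate name using a one-s.e.m. rule on training-data scores.

    decoding, self_pred: dicts mapping candidate name -> array of scores.
    focus="decoding": among candidates within 1 s.e.m. of the best decoding,
    pick the best self-prediction; focus="self_prediction" swaps the roles.
    """
    if focus == "decoding":
        primary, secondary = decoding, self_pred
    elif focus == "self_prediction":
        primary, secondary = self_pred, decoding
    else:
        raise ValueError(f"unknown focus: {focus}")
    stats = {name: _mean_sem(v) for name, v in primary.items()}
    best_mean, best_sem = max(stats.values(), key=lambda ms: ms[0])
    eligible = [n for n, (mu, _) in stats.items() if mu >= best_mean - best_sem]
    return max(eligible, key=lambda n: np.mean(secondary[n]))
```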
As an example of this independence, in the first optimization step of DPAD, model parameters are always learned to optimize behavior decoding (Eq. (5)). But once the four-step optimization is concluded and different models (with different combinations of nonlinearities) are learned, we can then select among these already-learned models based on either neural self-prediction or behavior decoding. Thus, whenever neural self-prediction is also of interest, we report the results for flexible nonlinearity based on both model selection criteria (for example, Figs. 3, 4 and 6).
Localization of nonlinearities
DPAD enables an inspection of where nonlinearities are localized by providing two capabilities, without either of which the origin of nonlinearities may be incorrectly identified. As the first capability, DPAD can train alternative models with different individual nonlinearities and then compare these alternative nonlinear models not only with a fully linear model but also with each other and with fully nonlinear models (that is, flexible nonlinearity). Indeed, our simulations showed that simply comparing a linear model to a model with nonlinearity in a given parameter may incorrectly identify the origin of nonlinearity (Extended Data Fig. 2b and Fig. 6a). For example, in Fig. 6a, although the nonlinearity is only in the neural input parameter, a linear model does worse than a model with a nonlinear behavior readout parameter. Thus, comparing only the latter model to a linear model would incorrectly identify the behavior readout as the origin of nonlinearity. DPAD avoids this issue because it can also train a model with a nonlinear neural input, which it then finds to be more predictive than models with any other individual nonlinearity and as predictive as a fully nonlinear model (Fig. 6a). As the second capability, DPAD can compare alternative nonlinear models in terms of overall neural–behavioral prediction rather than either behavior decoding or neural prediction alone.
Indeed, our simulations showed that comparing the models in terms of just behavior decoding (Extended Data Fig. 2d,f) or just neural self-prediction (Extended Data Fig. 2d,h) may lead to incorrect conclusions about the origin of nonlinearities. This is because a model with the incorrect origin may be equivalent in one of these metrics to the model with the correct origin. DPAD avoids this problem by jointly evaluating both neural–behavioral metrics. Here, when comparing models with nonlinearity in different individual parameters for localization purposes (for example, Fig. 6), we only consider one network architecture for the nonlinearity, that is, one hidden layer with 64 units.

Numerical simulations

To validate DPAD in numerical simulations, we perform two sets of simulations. One set validates linear modeling to show the correctness of the four-step numerical optimization for learning. The other set validates nonlinear modeling. In the linear simulation, we randomly generate 100 linear models with various dimensionalities and noise statistics, as described in our prior work6. Briefly, the neural and behavior dimensions ny and nz are each selected uniformly at random from 5 to 10. The state dimension is set to nx = 16, of which n1 = 4 latent state dimensions are selected to drive behavior. Eigenvalues of the state transition matrix are selected randomly as complex conjugate pairs with uniform probability within the unit disk. Each element in the behavior and neural readout matrices is generated as a random Gaussian variable. State and neural observation noise covariances are generated as random positive definite matrices. They are scaled by a random number between 0.003 and 0.3 (state noise) or between 0.01 and 100 (neural observation noise) to obtain a wide range of relative noises across random models.
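As a rough illustration of the linear-model generation described above, the sketch below draws one random model. Several details are assumptions rather than the exact procedure of ref. 6:

- eigenvalues are placed in 2 × 2 rotation-scaling blocks of a block-diagonal real matrix;
- the behavior readout is restricted to the first n1 states;
- positive definite covariances are built as MMᵀ plus a small ridge;
- the noise scales are drawn log-uniformly.

The separate behavior-noise model and the signal-to-noise rescaling of the behavior readout are omitted. All function names are hypothetical.

```python
import numpy as np

def random_stable_A(nx, rng):
    """Real nx-by-nx transition matrix whose eigenvalues are complex-conjugate
    pairs drawn uniformly within the unit disk (nx must be even)."""
    if nx % 2:
        raise ValueError("nx must be even to form conjugate pairs")
    A = np.zeros((nx, nx))
    for i in range(0, nx, 2):
        r = np.sqrt(rng.uniform())           # sqrt makes the density uniform over the disk's area
        theta = rng.uniform(0.0, np.pi)      # the 2x2 block supplies the conjugate eigenvalue
        c, s = r * np.cos(theta), r * np.sin(theta)
        A[i:i + 2, i:i + 2] = [[c, -s], [s, c]]   # eigenvalues r * exp(+/- 1j * theta)
    return A

def random_pd(n, scale, rng):
    """Random positive definite matrix whose mean diagonal entry equals `scale`."""
    M = rng.standard_normal((n, n))
    S = M @ M.T + 1e-3 * np.eye(n)
    return scale * S / np.mean(np.diag(S))

def log_uniform(lo, hi, rng):
    return float(np.exp(rng.uniform(np.log(lo), np.log(hi))))

def random_linear_model(rng, nx=16, n1=4):
    ny, nz = (int(v) for v in rng.integers(5, 11, size=2))   # 5 <= ny, nz <= 10
    A = random_stable_A(nx, rng)
    Cy = rng.standard_normal((ny, nx))
    Cz = np.zeros((nz, nx))
    # A is block diagonal, so the first n1 (even) states evolve on their own; only they drive behavior.
    Cz[:, :n1] = rng.standard_normal((nz, n1))
    Q = random_pd(nx, log_uniform(0.003, 0.3, rng), rng)     # state noise covariance
    R = random_pd(ny, log_uniform(0.01, 100.0, rng), rng)    # neural observation noise covariance
    return A, Cy, Cz, Q, R
```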
A separate random linear state-space model with four latent state dimensions is generated to produce the behavior readout noise 𝜖k, representing the behavior dynamics that are not encoded in the recorded neural activity. Finally, the behavior readout matrix is scaled to set the ratio of the signal standard deviation to noise standard deviation in each behavior dimension to a random number from 0.5 to 50. We perform model learning and evaluation with twofold cross-validation (Extended Data Fig. 1).

The nonlinear simulations validate both DPAD and the hypothesis-testing procedure it enables for finding the origin of nonlinearity. We start by generating 20 random linear models (ny = nz = 1) either with nx = nz = ny (Extended Data Fig. 2) or with nx = 2 latent states, only one of which drives behavior (Supplementary Fig. 2). We then introduce nonlinearity in one of the four model parameters (that is, A′, K, Cy or Cz) by replacing that parameter with a nonlinear trigonometric function, such that roughly one period of the trigonometric function is visited by the model (while keeping the rest of the parameters linear). To do this, we first apply a similarity transform to the initial random linear model that scales each latent state to have a 95% confidence interval range of 2π. We then add a sine function to the original parameter that is to be made nonlinear. The amplitude of the sine is scaled such that its output reaches roughly 0.25 of the range of the outputs from the original linear parameter. This limited amplitude reduces the chance of generating unrealistic, unstable nonlinear models that produce outputs with infinite energy, which is likely when A′ is nonlinear.
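The scaling-and-sine construction can be sketched for the scalar case (nx = ny = nz = 1) with the nonlinearity placed in Cz. This is a hypothetical illustration under several assumptions:

- the linear parameters are drawn so that the effective transition A′ + K·Cy is stable;
- the noise standard deviations are arbitrary;
- "reaches 0.25 of the range" is read as the sine's peak-to-peak output equaling 0.25 of the linear output's range.

The simulation follows the Eq. (1) ordering used for data generation in this section: y_k from x_k, then x_{k+1}. A nonlinear Cz leaves the latent states unchanged, so no rescaling loop is needed here. A nonlinear A′ or K would require re-simulating and rescaling.

```python
import numpy as np

def simulate(Ap, K, Cy, Cz, T, seed, sig_e=1.0, sig_eps=0.1):
    """Simulate a scalar model of the form of Eq. (1), starting from x_0 = 0."""
    rng = np.random.default_rng(seed)
    e = sig_e * rng.standard_normal(T)        # neural noise e_k
    x = np.zeros(T + 1)
    y = np.zeros(T)
    for k in range(T):
        y[k] = Cy(x[k]) + e[k]                 # second line: neural activity y_k
        x[k + 1] = Ap(x[k]) + K(y[k])          # first line: next state x_{k+1}
    x = x[:T]
    z = Cz(x) + sig_eps * rng.standard_normal(T)   # third line: behavior z_k
    return x, y, z

rng = np.random.default_rng(0)
a = rng.uniform(-0.95, 0.95)             # hypothetical effective transition A' + K*Cy
kg, cy, cz = rng.standard_normal(3)      # random linear K, Cy and Cz
ap = a - kg * cy                         # A' such that x_{k+1} = a*x_k + K*e_k (stable)
T, seed = 20000, 1

# Scale the state so its 95% interval spans 2*pi (similarity transform x -> s*x).
x, _, _ = simulate(lambda v: ap * v, lambda v: kg * v, lambda v: cy * v, lambda v: cz * v, T, seed)
lo, hi = np.percentile(x, [2.5, 97.5])
s = 2 * np.pi / (hi - lo)
kg_s, cy_s, cz_s = kg * s, cy / s, cz / s    # A' is unchanged for a scalar state

# Same noise seed, so the transformed model reproduces s * x (up to rounding).
x_s, _, _ = simulate(lambda v: ap * v, lambda v: kg_s * v, lambda v: cy_s * v, lambda v: cz_s * v, T, seed)

# Add a sine to the behavior readout; its peak-to-peak output is 0.25 of the linear output range.
lin_out = cz_s * x_s
amp = 0.25 * (lin_out.max() - lin_out.min()) / 2

def Cz_nl(v):
    return cz_s * v + amp * np.sin(v)

_, y_nl, z_nl = simulate(lambda v: ap * v, lambda v: kg_s * v, lambda v: cy_s * v, Cz_nl, T, seed)
```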
Making one parameter nonlinear can change the statistics of the latent states in the model, including their range. Thus, we generate some simulated data from the model and redo the scaling of the nonlinearity until the range and ratio conditions above are met.

To generate data from any nonlinear model in Eq. (1), we first generate a neural noise time series ek based on its covariance Σe in the model and initialize the state as x0 = 0. We then iteratively apply the second and first lines of Eq. (1) to get the simulated neural activity yk from line 2 and then the next state \({x}_{k+1}\) from line 1, respectively. Finally, once the state time series is produced, we generate a behavior noise time series 𝜖k based on its covariance Σ𝜖 in the model and apply the third line of Eq. (1) to get the simulated behavior zk. Similar to linear simulations, we perform the modeling and evaluation of nonlinear simulations with twofold cross-validation (Extended Data Fig. 2 and Supplementary Fig. 2).

Neural datasets and behavioral tasks

We evaluate DPAD in five datasets with different behavioral tasks, brain regions and neural recording modalities to show the generality of our conclusions. For each dataset, all animal procedures were performed in compliance with the National Research Council Guide for Care and Use of Laboratory Animals. They were approved by the Institutional Animal Care and Use Committee at the respective institution, namely New York University (datasets 1 and 2)6,45,46, Northwestern University (datasets 3 and 5)47,48,54 and University of California San Francisco (dataset 4)21,49.

Across all four main datasets (datasets 1 to 4), the spiking activity was binned with 10-ms nonoverlapping bins and smoothed with a Gaussian kernel with a standard deviation of 50 ms (refs. 6,14,34,103,104). It was then downsampled to a 50-ms time step to be used as the neural signal to be modeled. The behavior time series was also downsampled to a matching 50-ms time step before modeling.
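A minimal NumPy sketch of this spike preprocessing follows. The function names are hypothetical. The text does not specify the following details, which are assumptions here:

- the kernel is truncated at four standard deviations, with zero padding at the edges;
- downsampling keeps every fifth smoothed 10-ms sample.

```python
import numpy as np

def bin_spikes(spike_times_s, duration_s, bin_s):
    """Spike counts in nonoverlapping bins of width bin_s (seconds)."""
    n_bins = int(round(duration_s / bin_s))
    edges = np.arange(n_bins + 1) * bin_s
    counts, _ = np.histogram(spike_times_s, bins=edges)
    return counts.astype(float)

def gaussian_smooth(x, sigma_bins, truncate=4.0):
    """Convolve with a unit-sum Gaussian kernel (zero padding at the edges)."""
    half = int(np.ceil(truncate * sigma_bins))
    t = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (t / sigma_bins) ** 2)
    kernel /= kernel.sum()
    if len(x) < len(kernel):
        raise ValueError("signal is shorter than the smoothing kernel")
    return np.convolve(x, kernel, mode="same")

def preprocess_spikes(spike_times_per_channel, duration_s, bin_ms=10, sigma_ms=50, step_ms=50):
    """Bin, smooth and downsample spikes; returns an array of shape (time, channels)."""
    factor = step_ms // bin_ms                 # 5 for 10-ms bins and a 50-ms step
    channels = [
        gaussian_smooth(bin_spikes(st, duration_s, bin_ms / 1000.0), sigma_ms / bin_ms)
        for st in spike_times_per_channel
    ]
    return np.stack(channels, axis=1)[::factor]   # keep every factor-th smoothed sample
```

For example, a single spike at 0.5005 s in a 1-s recording falls in 10-ms bin 50. After smoothing and keeping every fifth sample, it peaks at output sample 10 of 20.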
In the three datasets where LFP activity was also available, we also studied two types of features extracted from LFP. As the first LFP feature, we considered raw LFP activity itself. It was high-pass filtered above 0.5 Hz to remove the baseline, low-pass filtered below 10 Hz (that is, for antialiasing) and downsampled to the behavior sampling rate of a 50-ms time step (that is, 20 Hz). Note that in the context of the motor cortex, low-pass-filtered raw LFP is also referred to as the local motor potential50,51,52,105,106 and has been used to decode behavior6,50,51,52,53,105,106,107. As the second feature, we computed the LFP log-powers5,6,7,40,77,79,106,108,109 in eight standard frequency bands (delta: 0.1–4 Hz, theta: 4–8 Hz, alpha: 8–12 Hz, low beta: 12–24 Hz, mid-beta: 24–34 Hz, high beta: 34–55 Hz, low gamma: 65–95 Hz and high gamma: 130–170 Hz). These were computed in sliding 300-ms windows at a time step of 50 ms using Welch’s method (using eight subwindows with 50% overlap)6. The median analyzed data length for each session across the datasets ranged from 4.6 to 9.9 min.

First dataset: 3D reaches to random targets

In the first dataset, the animal (named J) performed reaches to a target randomly positioned in 3D space within its reach, grasped the target and returned its hand to a resting position6,45. Kinematic data were acquired using the Cortex software package (version 5.3) to track retroreflective markers in 3D (Motion Analysis)6,45. Joint angles were solved from the 3D marker data using a rhesus macaque musculoskeletal model via the SIMM toolkit (version 4.0, MusculoGraphics)6,45. Angles of 27 joints in the shoulder, elbow, wrist and fingers in the active hand (right hand) were taken as the behavior signal6,45. Neural activity was recorded with a 137-electrode microdrive (Gray Matter Research), of which 28 electrodes were in the contralateral primary motor cortex (M1). The multiunit spiking activity in these M1 electrodes was used as the neural signal.
For LFP analyses, LFP features were also extracted from the same M1 electrodes. We analyzed the data from seven recording sessions.

To visualize the low-dimensional latent state trajectories for each behavioral condition (Extended Data Fig. 6), we determined the periods of reach and return movements in the data (Fig. 7a). We resampled them to have a similar number of time samples and averaged the latent states across those resampled trials. Latent descriptions are redundant: any scaling, rotation and so on of the latent states still gives an equivalent model. Therefore, before averaging trials across cross-validation folds and sessions, we devised the following procedure to standardize the latent states for each fold in the case of 2D latent states (Extended Data Fig. 6). (1) We z score all state dimensions to have zero mean and unit variance. (2) We rotate the 2D latent states such that the average 2D state trajectory for the first condition (here, the reach epochs) starts from an angle of 0. (3) We estimate the direction of rotation of the average 2D state trajectory of the first condition, and if it is not counterclockwise, we multiply the second state dimension by –1 to make it so. Note that in each step, the same mapping is applied to the latent states throughout the test data, regardless of condition. This procedure therefore does not alter the relative differences in the state trajectory across different conditions. The procedure also does not change the learned model and simply corresponds to a similarity transform that changes the basis of the model. It only removes the redundancies in describing a 2D latent state-space model and standardizes the extracted latent states so that trials across different test sets can be averaged together.

Second dataset: saccadic eye movements

In the second dataset, the animal (named A) performed saccadic eye movements to one of eight targets on a display6,46.
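The three-step standardization described above for 2D latent states, which is reused for the other datasets, could be implemented roughly as below. The following are assumptions, as are the helper names:

- each trial is linearly resampled to 50 samples;
- the trajectory's signed area around the origin decides its direction of rotation.

Trials are given as (start, stop) index pairs into the test-data states.

```python
import numpy as np

def resample(traj, n=50):
    """Linearly resample a (T, d) trajectory to n time samples."""
    t_old = np.linspace(0.0, 1.0, len(traj))
    t_new = np.linspace(0.0, 1.0, n)
    return np.column_stack([np.interp(t_new, t_old, traj[:, j]) for j in range(traj.shape[1])])

def condition_average(states, trials, n=50):
    """Average resampled state trajectories over trials given as (start, stop) pairs."""
    return np.mean([resample(states[a:b], n) for a, b in trials], axis=0)

def signed_area(traj):
    """Shoelace sum around the origin: positive for counterclockwise motion."""
    x, y = traj[:, 0], traj[:, 1]
    return 0.5 * np.sum(x[:-1] * y[1:] - x[1:] * y[:-1])

def standardize_2d(states, first_condition_trials, n=50):
    """Standardize the (T, 2) latent states of one fold's test data."""
    # (1) z score each state dimension over the whole test data
    X = (states - states.mean(axis=0)) / states.std(axis=0)
    # (2) rotate so the first condition's average trajectory starts at angle 0
    m = condition_average(X, first_condition_trials, n)
    phi = np.arctan2(m[0, 1], m[0, 0])
    c, s = np.cos(phi), np.sin(phi)
    X = X @ np.array([[c, s], [-s, c]]).T      # rotation by -phi applied to every sample
    # (3) flip the second dimension if that trajectory is not counterclockwise
    if signed_area(condition_average(X, first_condition_trials, n)) < 0:
        X[:, 1] *= -1
    return X
```

Each step is a linear map applied to all samples, so trajectories of every condition are transformed identically.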
The visual stimuli in the task with saccadic eye movements were controlled via custom LabVIEW (version 9.0, National Instruments) software executed on a real-time embedded system (NI PXI-8184, National Instruments)46. The 2D position of the eye was tracked and taken as the behavior signal. Neural activity was recorded with a 32-electrode microdrive (Gray Matter Research) covering the prefrontal cortex6,46. Single-unit activity from these electrodes, ranging from 34 to 43 units across different recording sessions, was used as the neural signal. For LFP analyses, LFP features were also extracted from the same 32 electrodes. We analyzed the data from the first 7 days of recordings. We only included data from successful trials in which the animal performed the task correctly by making a saccadic eye movement to the specified target. To visualize the low-dimensional latent state trajectories for each behavioral condition (Extended Data Fig. 6), we grouped the trials based on their target position. Standardization across folds before averaging was done as in the first dataset.

Third dataset: sequential reaches with a 2D cursor controlled with a manipulandum

In the third dataset, which was collected and made publicly available by the laboratory of L. E. Miller47,48, the animal (named T) controlled a cursor on a 2D screen using a manipulandum and performed a sequential reach task47,48. The 2D cursor position and velocity were taken as the behavior signal. Neural activity was recorded using a 100-electrode microelectrode array (Blackrock Microsystems) in the dorsal premotor cortex47,48. Single-unit activity, recorded from 37 to 49 units across recording sessions, was used as the neural signal. This dataset did not include any LFP recordings, so LFP features could not be considered. We analyzed the data from all three recording sessions. To visualize the low-dimensional latent state trajectories for each behavioral condition (Extended Data Fig. 6), we grouped the trials into eight conditions based on the angle of the direction of movement (that is, end position minus starting position) during the trial. Each condition covers movement directions within a 45° (that is, 360°/8) range. Standardization across folds before averaging was performed as in the first dataset.

Fourth dataset: virtual reality random reaches with a 2D cursor controlled with the fingertip

In the fourth dataset, which was collected and made publicly available by the laboratory of P. N. Sabes49, the animal (named I) controlled a cursor based on the fingertip position on a 2D surface within a 3D virtual reality environment21,49. The 2D cursor position and velocity were taken as the behavior signal. Neural activity was recorded with a 96-electrode microelectrode array (Blackrock Microsystems)21,49 covering M1. We selected a random subset of 32 of these electrodes, which had 77 to 99 single units across the recording sessions, as the neural signal. LFP features were also extracted from the same 32 electrodes. We analyzed the data for the first seven sessions for which the wideband activity was also available (sessions 20160622/01 to 20160921/01). Grouping into conditions for visualization of low-dimensional latent state trajectories (Extended Data Fig. 6) was done as in the third dataset. Standardization across folds before averaging was done as in the first dataset.

Fifth dataset: center-out cursor control reaching task

In the fifth dataset, which was collected and made publicly available by the laboratory of L. E. Miller54, the animal (named H) controlled a cursor on a 2D screen using a manipulandum and performed reaches from a center point to one of eight peripheral targets (Fig. 4i). The 2D cursor position was taken as the behavior signal. Neural activity was recorded with a 96-electrode microelectrode array (Blackrock Microsystems) covering area 2 of the somatosensory cortex54.
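The grouping of trials into eight movement-direction conditions used for the third and fourth datasets can be sketched as follows. The text specifies only the 45° width, so placing the bin edges at 0°, 45°, 90° and so on is an assumption. `direction_condition` is a hypothetical name.

```python
import numpy as np

def direction_condition(start_xy, end_xy, n_conditions=8):
    """Assign trials to equal angular bins based on the movement direction (end - start)."""
    d = np.asarray(end_xy, float) - np.asarray(start_xy, float)
    angle = np.degrees(np.arctan2(d[..., 1], d[..., 0])) % 360.0   # direction in [0, 360)
    width = 360.0 / n_conditions                                    # 45 degrees for 8 conditions
    return (angle // width).astype(int)
```

With starting positions at the origin, movements toward (1, 0.1), (0.1, 1), (-0.1, 1), (-1, -0.01) and (1, -0.1) fall into conditions 0, 1, 2, 4 and 7, respectively.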
Preprocessing for this dataset was done as in ref. 36. Specifically, the spiking activity was binned with 1-ms nonoverlapping bins and smoothed with a Gaussian kernel with a standard deviation of 40 ms (ref. 110). The behavior was also sampled at the same 1-ms rate. Trials were also aligned as in the same prior work110, with data from –100 to 500 ms around movement onset of each trial used for modeling36.

Additional details for baseline methods

The fifth dataset was analyzed in ref. 36, the study that introduced CEBRA. For this dataset, we used exactly the same CEBRA hyperparameters as those reported in ref. 36 (Fig. 4i,j). For each of the other four datasets (Fig. 4a–h), we learned a CEBRA-Behavior or CEBRA-Time model for each session, fold and latent dimension. In doing so, we also performed an extensive search over CEBRA hyperparameters and picked the best set with the same inner cross-validation approach as we use for the automatic selection of nonlinearities in DPAD. We considered 30 different sets of hyperparameters: 3 options for the ‘time-offset’ hyperparameter (1, 2 or 10) and 10 options for the ‘temperature’ hyperparameter (from 0.0001 to 0.01). These were designed to include all sets of hyperparameters reported for primate data in ref. 36. We swept the CEBRA latent dimension over the same values as DPAD, that is, powers of 2 up to 128. In all cases, we used a k-nearest neighbors regression to map the CEBRA-extracted latent embeddings to behavior and neural data as done in ref. 36, because CEBRA itself does not learn a reconstruction model36 (Extended Data Table 1).

It is important to note that CEBRA and DPAD have fundamentally different architectures and goals (Extended Data Table 1). CEBRA uses a small ten-sample window (when ‘model_architecture’ is ‘offset10-model’) around each datapoint to extract a latent embedding via a series of convolutions. By contrast, DPAD learns a dynamical model that recursively aggregates all past neural data to extract an embedding. Also, in contrast to CEBRA-Behavior, DPAD’s embedding includes and dissociates both behaviorally relevant neural dimensions and other neural dimensions. It thus predicts not only the behavior but also the neural data well. Finally, CEBRA does not learn mappings from its latent embeddings back to neural data or behavior during learning but fits them post hoc, whereas DPAD learns these mappings for all its latent states. Given these differences, several use-cases of DPAD are not targeted by CEBRA, including explicit dynamical modeling of neural–behavioral data (use-case 1), flexible nonlinearity, hypothesis testing regarding the origin of nonlinearity (use-case 4) and forecasting.

Statistics

We used the Wilcoxon signed-rank test for all paired statistical tests.

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.
