LFADS Hyperparameters

The following is a nearly exhaustive list of hyperparameters that affect the training and posterior mean sampling of the model. If the parameter name begins with c_, this parameter will be passed to the LFADS Python code directly, without the c_ prefix. All of these values are specified in the RunParams instance that accompanies each Run.
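As a rough illustration of the c_ convention, the sketch below splits a set of parameters into those forwarded to the Python code (with the prefix stripped) and those used only by the run manager. The helper name `split_params` and the dictionary representation are hypothetical and are not part of the run manager's API.

```python
def split_params(params):
    """Split parameters into (forwarded to LFADS, kept by the run manager).

    Hypothetical helper: it only illustrates the c_ naming convention.
    """
    prefix = "c_"
    forwarded = {}
    local = {}
    for name, value in params.items():
        if name.startswith(prefix):
            forwarded[name[len(prefix):]] = value  # strip the c_ prefix
        else:
            local[name] = value
    return forwarded, local


forwarded, local = split_params(
    {"c_batch_size": 256, "c_co_dim": 4, "spikeBinMs": 2}
)
print(forwarded)  # {'batch_size': 256, 'co_dim': 4}
print(local)      # {'spikeBinMs': 2}
```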

While there are many parameters, you will likely care about only a small subset of them. We have color-coded the hyperparameters below according to how often they require tuning:

Tuning Frequency Description
Common Typically requires tuning on a per-project basis and/or is important to set appropriately upfront.
Occasional Might be adjusted for fine tuning.
Rare Infrequently requires tuning and/or primarily intended for advanced users.

Run Manager logistics and data processing

Name Default Description
name '' Name of this set of parameters, used for convenience only. Note that this parameter does not affect either the param or data hash.
version n/a You should not assign this value directly, as it will automatically be set to match the version of the RunCollection to which it is added. This is used for graceful backwards compatibility. Note that this parameter does not affect either the param or data hash.
spikeBinMs 2 Spike bin width in milliseconds. This must be an integer multiple of the original bin width produced by `generateCountsForDataset` in the `Run` class.
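The integer-multiple requirement on spikeBinMs can be pictured as summing adjacent bins. The sketch below is a minimal illustration in Python, not the run manager's own rebinning code; the function name `rebin_counts` is hypothetical, and dropping a trailing partial bin is an assumption made here for simplicity.

```python
def rebin_counts(counts, original_bin_ms, spike_bin_ms):
    """Sum adjacent bins so that each new bin spans spike_bin_ms milliseconds.

    Hypothetical helper. Assumes a trailing partial bin is dropped.
    """
    if spike_bin_ms % original_bin_ms != 0:
        raise ValueError("spike_bin_ms must be an integer multiple of original_bin_ms")
    factor = spike_bin_ms // original_bin_ms
    n_full = len(counts) // factor  # number of complete new bins
    return [sum(counts[i * factor:(i + 1) * factor]) for i in range(n_full)]


# Seven 1 ms bins rebinned to 2 ms; the last (unpaired) bin is dropped.
print(rebin_counts([1, 0, 2, 1, 0, 0, 3], original_bin_ms=1, spike_bin_ms=2))  # [1, 3, 0]
```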

TensorFlow logistics

Name Default Description
c_allow_gpu_growth true Allow the GPU to dynamically allocate memory instead of allocating all the GPU's memory at the start
c_max_ckpt_to_keep 5 Max number of checkpoints to keep (rolling)
c_max_ckpt_to_keep_lve 5 Max number of checkpoints to keep for lowest validation error models (rolling)
c_device 'gpu:0' Which visible GPU or CPU to use. Note that GPUs are typically scheduled by setting `CUDA_VISIBLE_DEVICES` rather than using this parameter.

Optimization

Rather than putting the learning rate on an exponentially decreasing schedule, the current algorithm monitors the cost, and if the cost isn’t regularly decreasing, it decreases the learning rate. So far, it works fine, though it is not perfect.
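A minimal sketch of this rule, using the defaults listed in the table in this section, might look as follows. It assumes that "worse" means the current cost exceeds every one of the previous n costs, and that the check runs once per recorded cost; the function name `adaptive_learning_rates` is hypothetical and this is not the LFADS training loop itself.

```python
def adaptive_learning_rates(costs, lr_init=0.01, decay_factor=0.95,
                            n_to_compare=6, lr_stop=1e-5):
    """Return the learning rate in effect after each recorded cost.

    Assumption: the rate is lowered when the current cost is worse than
    all of the previous n_to_compare costs.
    """
    lr = lr_init
    history = []
    rates = []
    for cost in costs:
        recent = history[-n_to_compare:]
        if len(recent) == n_to_compare and cost > max(recent):
            lr *= decay_factor  # no progress: lower the learning rate
        history.append(cost)
        rates.append(lr)
        if lr <= lr_stop:
            break  # training stops once the threshold is reached
    return rates


rates = adaptive_learning_rates([10, 9, 8, 7, 6, 5, 4, 11, 3])
print(rates)
```

In this toy run the rate is lowered once, at the eighth cost (11), which is worse than each of the six costs before it.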

Name Default Description
c_learning_rate_init 0.01 Initial learning rate
c_learning_rate_decay_factor 0.95 Factor by which to decrease the learning rate if progress isn't being made.
c_learning_rate_n_to_compare 6 Number of previous costs current cost has to be worse than, in order to lower learning rate.
c_learning_rate_stop 0.00001 Stop training when the learning rate reaches this threshold.
c_max_grad_norm 200 Maximum gradient norm; gradients whose norm exceeds this value are clipped. This hyperparameter is extremely useful for avoiding an infrequent but highly pathological problem whereby the gradient is so large that it destroys the optimization by setting parameters too large, leading to a vicious cycle that ends in NaNs. If it's too large, it's useless; if it's too small, it essentially becomes the learning rate. It's pretty insensitive, though.
trainToTestRatio 4 Ratio of training vs testing trials used.
c_batch_size 256 Number of trials to use during each training pass. The total trial count must be ≥ c_batch_size * (trainToTestRatio + 1).
c_cell_clip_value 5 Max value recurrent cell can take before being clipped. If your optimizations start "NaN-ing out", reduce this value so that the values of the network don't grow out of control. Typically, once this parameter is set to a reasonable value, one stops having numerical problems.
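The trial-count constraint on c_batch_size is simple arithmetic, sketched below. The function names are hypothetical; the only thing taken from the table is the inequality itself.

```python
def min_trials_required(batch_size, train_to_test_ratio):
    """Smallest total trial count satisfying the constraint in the table."""
    return batch_size * (train_to_test_ratio + 1)


def max_batch_size(n_trials, train_to_test_ratio):
    """Largest batch size allowed for a given total trial count."""
    return n_trials // (train_to_test_ratio + 1)


print(min_trials_required(256, 4))  # 1280
print(max_batch_size(500, 4))       # 100
```

With the defaults, at least 1280 trials are needed; with only 500 trials, c_batch_size would need to be reduced to 100 or less.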

Overfitting

If the controller is heavily penalized, then it won’t have any output. If the dynamics are heavily penalized, then the generator won’t make dynamics. Note that this L2 penalty is only on the recurrent portion of the RNNs, since dropout is also available to regularize the feed-forward connections.

Name Default Description
c_temporal_spike_jitter_width 0 Enables jittering spike times during training. It appears that the system will happily fit spikes (blessing or curse, depending). You may not want this. Jittering the spikes a bit may help (-/+ bin size, as specified here). The idea is to prevent LFADS from trying to learn very fine temporal structure in the data if you believe this to be noise.
c_keep_prob 0.95 Dropout keep probability: the fraction of units randomly kept (not dropped) during each training pass, so the default of 0.95 drops roughly 5% of units. Dropout is done on the input data, on controller inputs (from encoder), and on outputs from generator to factors.
c_l2_gen_scale 500 L2 regularization cost for the generator only.
c_co_mean_corr_scale 0 Cost of correlation (through time) in the means of controller output.
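To make the idea behind c_temporal_spike_jitter_width concrete, the sketch below moves each spike in a binned trial by a random number of bins within plus or minus the jitter width. It is an illustration only: the uniform integer offsets and the clamping at the trial edges are assumptions, and `jitter_spike_counts` is a hypothetical name rather than a function from the LFADS code.

```python
import random


def jitter_spike_counts(counts, jitter_width, rng):
    """Shift every spike by a random offset in [-jitter_width, +jitter_width] bins.

    Assumptions: offsets are uniform integers, and spikes pushed past
    either end of the trial are clamped to the first or last bin.
    """
    n_bins = len(counts)
    jittered = [0] * n_bins
    for t, count in enumerate(counts):
        for _ in range(count):
            shift = rng.randint(-jitter_width, jitter_width)
            new_t = min(max(t + shift, 0), n_bins - 1)
            jittered[new_t] += 1
    return jittered


rng = random.Random(0)
original = [0, 2, 1, 0, 3]
print(jitter_spike_counts(original, jitter_width=1, rng=rng))
```

The total spike count is preserved; only the fine timing is perturbed, and a width of 0 leaves the data unchanged.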

Underfitting

If the primary task of LFADS is “filtering” of data and not generation, then it is possible that the KL penalty is too strong. Empirically, we have found this to be the case. So we add a hyperparameter in front of the two KL terms (one for the initial conditions to the generator, the other for the controller outputs). You should always think of the default values as 1.0, which leads to a standard VAE formulation whereby the numbers that are optimized are a lower bound on the log-likelihood of the data. When these two hyperparameters deviate from 1.0, one cannot make any statement about what those log-likelihood lower bounds mean anymore, and they cannot be compared (AFAIK).

Sometimes the task can be sufficiently hard to learn that the optimizer takes the ‘easy route’, and simply minimizes the KL divergence, setting it to near zero, and the optimization gets stuck. The same possibility is true for the L2 regularizer. One wants a simple generator, for scientific reasons, but not at the expense of hosing the optimization. The last 5 parameters help avoid that by getting the optimization to ‘latch’ on to the main optimization, and only turning on the regularizers gradually by increasing their weighting in the overall cost functions later.

Name Default Description
c_kl_ic_weight 1 Strength of KL weight on initial conditions KL penalty.
c_kl_co_weight 1 Strength of KL weight on controller output KL penalty.
c_kl_start_step 0 Start increasing KL weight after this many steps.
c_kl_increase_steps 900 Number of steps over which the KL weight increases.
c_l2_start_step 0 Start increasing L2 weight after this many steps.
c_l2_increase_steps 900 Number of steps over which the L2 weight increases.
scaleIncreaseStepsWithDatasets true If true, c_kl_increase_steps and c_l2_increase_steps will be multiplied by the number of datasets in a stitching run.
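The gradual turn-on of the KL and L2 penalties can be sketched as a weight that ramps from 0 up to its final value. The linear shape of the ramp is an assumption made for this illustration (the table only says the weight increases over the given number of steps), and both function names are hypothetical.

```python
def ramp_weight(step, start_step, increase_steps, final_weight=1.0):
    """Penalty weight at a given training step.

    Assumption: the weight is 0 until start_step, then rises linearly
    to final_weight over increase_steps steps.
    """
    if increase_steps <= 0:
        return final_weight if step >= start_step else 0.0
    fraction = (step - start_step) / increase_steps
    return final_weight * min(max(fraction, 0.0), 1.0)


def effective_increase_steps(increase_steps, n_datasets, scale_with_datasets=True):
    """Apply the scaleIncreaseStepsWithDatasets rule for stitching runs."""
    return increase_steps * n_datasets if scale_with_datasets else increase_steps


print(ramp_weight(450, start_step=0, increase_steps=900))   # 0.5
print(ramp_weight(2000, start_step=0, increase_steps=900))  # 1.0
print(effective_increase_steps(900, n_datasets=3))          # 2700
```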

External inputs

If there are observed inputs, there are two ways to add them to the model. The first is to treat the observed input as something to be inferred, encoding it via the encoders and then feeding it to the generator via the “inferred inputs” channel. The second is to inject the observed input directly into the generator. This has the downside of making the generation process strictly dependent on knowing the observed input for any generated trial.

Name Default Description
c_ext_input_dim 0 Number of external, known (or observed) inputs.
c_inject_ext_input_to_gen false Should the known inputs be input to model via encoders (false) or injected directly into generator (true)?

Controller and inferred inputs

The controller will be more powerful if it can see the encoding of the entire trial. However, this allows the controller to create inferred inputs that are acausal with respect to the actual data generation process. For example, the data generator could have an input at time t, but the controller, after seeing the entirety of the trial, could infer that the input is coming a little before time t, because there are no restrictions on the data the controller sees. One can force the controller to be causal (with respect to perturbations in the data generator) so that it only sees forward encodings of the data at time t that originate at times before or at time t. One can also control the data the controller sees by using an input lag (forward encoding at time t-tlag for controller input at time t). The same can be done in the reverse direction (controller input at time t from reverse encoding at time t+tlag, in the case of an acausal controller). Setting this lag > 0 (even lag=1) can be a powerful way of avoiding very spiky decodes. Finally, one can manually control whether the factors at time t-1 are fed to the controller at time t.
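The indexing described above can be sketched as follows: at time t the controller receives the forward encoding from t-tlag and, unless it is causal, the reverse encoding from t+tlag. This is an illustration of the indexing only. The function name is hypothetical, and marking out-of-range time steps with `None` is an assumption rather than a description of how the LFADS graph pads its inputs.

```python
def controller_inputs(fwd_enc, rev_enc, lag, causal):
    """Pair up the (forward, reverse) encodings seen by the controller at each step.

    Assumptions: lag >= 0, both encodings have one entry per time step,
    and entries that fall outside the trial are represented by None.
    """
    n_steps = len(fwd_enc)
    inputs = []
    for t in range(n_steps):
        fwd = fwd_enc[t - lag] if t - lag >= 0 else None
        if causal:
            rev = None  # a causal controller never sees the reverse encoding
        else:
            rev = rev_enc[t + lag] if t + lag < n_steps else None
        inputs.append((fwd, rev))
    return inputs


fwd = ["f0", "f1", "f2", "f3"]
rev = ["r0", "r1", "r2", "r3"]
print(controller_inputs(fwd, rev, lag=1, causal=False))
# [(None, 'r1'), ('f0', 'r2'), ('f1', 'r3'), ('f2', None)]
print(controller_inputs(fwd, rev, lag=1, causal=True))
# [(None, None), ('f0', None), ('f1', None), ('f2', None)]
```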

If you don’t care about any of this, and just want to smooth your data, set do_causal_controller = False, do_feed_factors_to_controller = True, controller_input_lag = 0.

Name Default Description
c_co_dim 4 Number of inferred inputs (controller outputs). This parameter critically controls whether or not there is a controller (along with controller encoders) placed into the LFADS graph. If equal to 0, no controller will be added.
c_prior_ar_atau 10 Initial autocorrelation of AR(1) priors (in time bins)
c_do_train_prior_ar_atau true Is the value for atau an initial value (true) or the constant value (false)?
c_prior_ar_nvar 0.1 Initial noise variance for AR(1) priors
c_do_train_prior_ar_nvar true Is the value for the noise var an initial value (true) or the constant value (false)?
c_do_causal_controller false Restrict input encoder from seeing the future?
c_do_feed_factors_to_controller true Should factors[t-1] be input to controller at time t? Strictly speaking, feeding either the factors or the rates to the controller violates causality, since the g0 gets to see all the data. This may or may not be only a theoretical concern.
c_feedback_factors_or_rates 'factors' Feedback the factors or the rates to the controller? Set to either 'factors' or 'rates'
c_controller_input_lag 1 Time lag on the encoding fed to the controller: t-lag for forward, t+lag for reverse.
c_ci_enc_dim 128 Network size for controller input encoder.
c_con_dim 128 Controller dimensionality.
c_co_prior_var_scale 0.1 Variance of control input prior distribution.

Encoder and initial conditions for generator

Note that the dimension of the initial conditions is separated from the dimensions of the generator initial conditions (and a linear matrix will adapt the shapes if necessary). This is just another way to control complexity. In all likelihood, setting the IC dims to the size of the generator hidden state is just fine.

For the initial condition prior variance parameters, it’s best to leave them alone. The defaults should be fine for most cases, regardless of other parameters. If you don’t want the prior variance to be learned, set the following values to the same thing: ic_prior_var_min, ic_prior_var_scale, ic_prior_var_max. The prior mean will still be learned. If you really want to limit the information from encoder to decoder, increase ic_post_var_min above 0.

Name Default Description
c_num_steps_for_gen_ic MAXINT Number of steps to train the generator initial condition.
c_ic_dim 64 Dimensionality of the initial conditions.
c_ic_enc_dim 128 Network size for IC encoder.
c_ic_prior_var_min 0.1 Minimum variance of IC prior distribution
c_ic_prior_var_scale 0.1 Variance of IC prior distribution
c_ic_prior_var_max 0.1 Maximum variance of IC prior distribution
c_ic_post_var_min 0.0001 Minimum variance of IC posterior distribution
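One way to see why setting the minimum, scale, and maximum to the same value fixes the prior variance is to picture the learned variance as being clamped to the [min, max] range. The sketch below assumes exactly that clamping behavior, which is a simplification for illustration and not a statement about the LFADS implementation; `clamp_prior_variance` is a hypothetical name.

```python
def clamp_prior_variance(learned_var, var_min, var_max):
    """Clamp a learned variance to [var_min, var_max] (assumed behavior)."""
    return min(max(learned_var, var_min), var_max)


# With the defaults (min = scale = max = 0.1) every learned value maps to 0.1.
for candidate in (0.01, 0.1, 5.0):
    print(clamp_prior_variance(candidate, var_min=0.1, var_max=0.1))

# With a wider range, values inside the range pass through unchanged.
print(clamp_prior_variance(0.5, var_min=0.01, var_max=1.0))  # 0.5
```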

Generator network, factors, rates

Controlling the size of the generator is one way to control the complexity of the dynamics (there is also L2 regularization, which will squeeze out unnecessary dynamics). The modern deep learning approach is to make these cells as large as tolerable (from a waiting perspective), and then regularize them to death with dropout or whatever. It is not clear whether this is correct for the LFADS application.

Name Default Description
c_cell_weight_scale 1.0 Scaling for the weights of the encoder and controller cells. The combined recurrent and input weights of the encoder and controller cells are by default set to scale at ws/sqrt(#inputs) with ws=1.0. You can change this scaling with this parameter.
c_gen_dim 100 Generator network size
c_gen_cell_input_weight_scale 1.0 Input scaling for input weights in generator, which will be divided by sqrt(#inputs)
c_gen_cell_rec_weight_scale 1.0 Input scaling for recurrent weights in generator.
c_factors_dim 50 Dimensionality of factors read out from generator network. This provides dimensionality reduction from generator dimensionality down to factors and then back out to the neural rates. Note that this property does affect the data and param hashes, unlike the other c_ prefixed parameters, which only affect the param hash.
c_output_dist 'poisson' Type of output distribution for rates, either 'poisson' or 'gaussian'
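The ws/sqrt(#inputs) scaling used by the weight scale parameters can be illustrated with a small initializer. The formula for the scale comes from the table; drawing the weights from a zero-mean Gaussian with that standard deviation is an assumption made for this sketch, and the function names are hypothetical.

```python
import math
import random


def weight_std(weight_scale, n_inputs):
    """Scale of the initial weights: ws / sqrt(#inputs)."""
    return weight_scale / math.sqrt(n_inputs)


def init_weights(n_inputs, n_units, weight_scale, rng):
    """Draw an n_inputs x n_units weight matrix (Gaussian draw is an assumption)."""
    std = weight_std(weight_scale, n_inputs)
    return [[rng.gauss(0.0, std) for _ in range(n_units)] for _ in range(n_inputs)]


print(weight_std(1.0, 100))  # 0.1
weights = init_weights(n_inputs=100, n_units=4, weight_scale=1.0, rng=random.Random(0))
print(len(weights), len(weights[0]))  # 100 4
```

Doubling the weight scale doubles the typical magnitude of the initial weights, while a larger number of inputs shrinks it.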

Stitching multi-session models

Name Default Description
c_do_train_readin true For stitching models, make the readin matrices trainable (true) or fix them to equal the alignment matrices (false). The per-session readin matrices map from neurons to input factors which are fed into the shared encoder. These are initialized by the alignment matrices and can subsequently be fixed or made trainable.
useAlignmentMatrix false Whether to use an alignment matrix when stitching datasets together.
useSingleDatasetAlignmentMatrix false When only using a single dataset, it is also possible to use a readin matrix that reduces the dimensionality of the spikes before inputting these input factors to the encoder networks. If set true, this will set up this readin matrix and seed it with an alignment matrix computed using PCA.
alignmentApproach 'regressGlobalPCs' Algorithm to use when calculating the initial alignment (readin) matrices that map from each session-specific set of neurons to common input factors. Default 'regressGlobalPCs' computes the PCs of all neurons together, and then regresses those PCs on the neurons from each session separately. This mapping produces the best possible linear reconstruction of the global PCs from each session. 'ridgeRegressGlobalPCs' does the same but uses ridge regression (L2 regularization) for more robustness, using a lambda value optimized via cross-validation. This approach may yield better results if some neurons are quite variable or exhibit sparse firing. Additional modes can be implemented; see LFADS.Run/prepareAlignmentMatrices which you may override in your derived `Run` class.
alignmentExtraArgs {} Extra arguments to be passed by LFADS.Run/prepareAlignmentMatrices, either to LFADS.MultisessionAlignmentTool.computeAlignmentMatricesUsingTrialAveragedPCR in the default implementation, or as additional parameters to a custom alignment algorithm.

Posterior sampling

Name Default Description
posterior_mean_kind 'posterior_sample_and_average' Mechanism to obtain the posterior mean. Use 'posterior_sample_and_average' to take a specified number of samples from the posterior distribution, run them through the model, and average the results, or 'posterior_push_mean' to use the posterior mean of the ICs and inputs and push those through the model directly. Since there are nonlinearities in the network, pushing the mean through need not be equivalent to the mean of the samples, but in practice it's usually pretty close, and is much faster to compute. Note that this parameter does not affect either the param or data hash.
num_samples_posterior 512 Number of samples of the posterior to use when using 'posterior_sample_and_average'. Note that this parameter does not affect either the param or data hash.
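The difference between the two posterior mean mechanisms can be seen with a toy model. In the sketch below a single exponential nonlinearity stands in for the whole network, which is purely an assumption for illustration; the function names are hypothetical and this is not how the run manager invokes LFADS.

```python
import math
import random


def toy_rate(z):
    """Toy nonlinearity standing in for the network (assumption)."""
    return math.exp(z)


def sample_and_average(samples):
    """Push every posterior sample through the model, then average the outputs."""
    return sum(toy_rate(z) for z in samples) / len(samples)


def push_mean(posterior_mean):
    """Push only the posterior mean through the model."""
    return toy_rate(posterior_mean)


rng = random.Random(0)
mu, sigma, num_samples = 0.0, 0.1, 512
samples = [rng.gauss(mu, sigma) for _ in range(num_samples)]
print(sample_and_average(samples), push_mean(mu))
```

With a narrow posterior the two numbers are close, while a wide posterior combined with a strong nonlinearity pulls them apart. Pushing the mean requires a single pass through the model instead of one pass per sample.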