LFADS Hyperparameters¶
The following is a nearly exhaustive list of hyperparameters that affect the training and posterior mean sampling of the model. If the parameter name begins with `c_`, this parameter will be passed to the LFADS Python code directly, without the `c_` prefix. All of these values are specified in the `RunParams` instance that accompanies each `Run`.
While there are many parameters, you will likely care about only a small subset of them. We have color-coded the hyperparameters below according to how often they require tuning:
Tuning Frequency | Description |
---|---|
Common | Typically requires tuning on a per-project basis and/or is important to set appropriately upfront. |
Occasional | Might be adjusted for fine tuning. |
Rare | Infrequently requires tuning and/or primarily intended for advanced users. |
Run Manager logistics and data processing¶
Name | Default | Description |
---|---|---|
name | '' | Name of this set of parameters, used for convenience only. Note that this parameter does not affect either the param or data hash. |
version | n/a | You should not assign this value directly, as it will automatically be set to match the version of the RunCollection to which it is added. This is used for graceful backwards compatibility. Note that this parameter does not affect either the param or data hash. |
spikeBinMs | 2 | Spike bin width in milliseconds. This must be an integer multiple of the original bin width returned by `generateCountsForDataset` in the `Run` class. |
TensorFlow logistics¶
Name | Default | Description |
---|---|---|
c_allow_gpu_growth | true | Allow the GPU to dynamically allocate memory instead of allocating all the GPU's memory at the start |
c_max_ckpt_to_keep | 5 | Max number of checkpoints to keep (rolling) |
c_max_ckpt_to_keep_lve | 5 | Max number of checkpoints to keep for lowest validation error models (rolling) |
c_device | 'gpu:0' | Which visible GPU or CPU to use. Note that GPUs are typically scheduled by setting `CUDA_VISIBLE_DEVICES` rather than using this parameter. |
Optimization¶
Rather than putting the learning rate on a fixed, exponentially decreasing schedule, the current algorithm monitors the training cost and lowers the learning rate whenever the cost isn’t regularly decreasing. So far, it works fine, though it is not perfect.
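The rule described above can be sketched as follows. This is a minimal illustration, not the LFADS implementation: the function name `update_learning_rate` and the exact comparison (current cost worse than all of the previous `n_to_compare` costs) are assumptions based on the descriptions of `c_learning_rate_decay_factor` and `c_learning_rate_n_to_compare`.

```python
def update_learning_rate(lr, train_costs, decay_factor=0.95, n_to_compare=6):
    """Return the learning rate to use next, given the cost history.

    train_costs holds one training cost per epoch, most recent last.
    Assumption: the rate is lowered only when the current cost is worse
    than every one of the previous n_to_compare costs.
    """
    if len(train_costs) > n_to_compare:
        current = train_costs[-1]
        previous = train_costs[-(n_to_compare + 1):-1]
        if current > max(previous):
            return lr * decay_factor
    return lr


# Costs are still improving, so the rate is unchanged.
lr_improving = update_learning_rate(0.01, [7, 6, 5, 4, 3, 2, 1])

# The latest cost is worse than the six before it, so the rate is decayed.
lr_stalled = update_learning_rate(0.01, [1, 1, 1, 1, 1, 1, 2])
```

Under this sketch, training would end once the decayed rate falls to `c_learning_rate_stop` or below.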
Name | Default | Description |
---|---|---|
c_learning_rate_init | 0.01 | Initial learning rate |
c_learning_rate_decay_factor | 0.95 | Factor by which to decrease the learning rate if progress isn't being made. |
c_learning_rate_n_to_compare | 6 | Number of previous costs current cost has to be worse than, in order to lower learning rate. |
c_learning_rate_stop | 0.00001 | Stop training when the learning rate reaches this threshold. |
c_max_grad_norm | 200 | Maximum gradient norm; gradients whose norm exceeds this value are clipped. This hyperparameter is extremely useful for avoiding an infrequent but highly pathological problem in which the gradient is so large that it destroys the optimization by setting parameters too large, leading to a vicious cycle that ends in NaNs. If it's too large, it's useless; if it's too small, it essentially becomes the learning rate. It's pretty insensitive, though. |
trainToTestRatio | 4 | Ratio of training vs testing trials used. |
c_batch_size | 256 | Number of trials to use during each training pass. The total trial count must be ≥ c_batch_size * (trainToTestRatio + 1) . |
c_cell_clip_value | 5 | Max value recurrent cell can take before being clipped. If your optimizations start "NaN-ing out", reduce this value so that the values of the network don't grow out of control. Typically, once this parameter is set to a reasonable value, one stops having numerical problems. |
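The trial-count constraint stated for `c_batch_size` can be checked before launching a run. The helper names below are hypothetical and only restate the inequality from the table; they are not part of the run manager.

```python
def min_trials_required(batch_size=256, train_to_test_ratio=4):
    """Smallest total trial count satisfying the constraint in the table."""
    return batch_size * (train_to_test_ratio + 1)


def max_batch_size(n_trials, train_to_test_ratio=4):
    """Largest batch size that a dataset with n_trials can support."""
    return n_trials // (train_to_test_ratio + 1)


# With the defaults, at least 256 * (4 + 1) trials are needed.
needed = min_trials_required()

# A dataset with only 400 trials would need a smaller batch size.
largest = max_batch_size(400)
```

If a dataset has fewer trials than `needed`, lowering `c_batch_size` to at most `largest` satisfies the inequality.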
Overfitting¶
If the controller is heavily penalized, it won’t produce any output. If the dynamics are heavily penalized, the generator won’t produce dynamics. Note that this L2 penalty applies only to the recurrent portion of the RNNs; dropout is also available and regularizes the feed-forward connections.
Name | Default | Description |
---|---|---|
c_temporal_spike_jitter_width | 0 | Enables jittering spike times during training. It appears that the system will happily fit spikes (blessing or curse, depending). You may not want this. Jittering the spikes a bit may help (-/+ bin size, as specified here). The idea is to prevent LFADS from trying to learn very fine temporal structure in the data if you believe this to be noise. |
c_keep_prob | 0.95 | Fraction of units to randomly keep during each training pass (the dropout keep probability, so 0.95 drops 5% of units). Dropout is applied to the input data, to the controller inputs (from the encoder), and to the outputs from the generator to the factors. |
c_l2_gen_scale | 500 | L2 regularization cost for the generator only. |
c_co_mean_corr_scale | 0 | Cost of correlation (through time) in the means of controller output. |
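To make the effect of `c_temporal_spike_jitter_width` concrete, here is a toy sketch of jittering binned spike counts for one neuron. It is an illustration only: the function name is hypothetical, and both the uniform integer offset and the clamping at the trial edges are assumptions rather than a description of how LFADS implements jitter.

```python
import random


def jitter_spike_counts(counts, jitter_width, rng=random):
    """Move each spike by a random offset in [-jitter_width, +jitter_width] bins.

    counts is a list of spike counts per time bin for one neuron.
    Spikes pushed past either end of the trial are clamped to the edge bin
    (an assumption made for this sketch).
    """
    n_bins = len(counts)
    jittered = [0] * n_bins
    for t, n_spikes in enumerate(counts):
        for _ in range(n_spikes):
            new_t = t + rng.randint(-jitter_width, jitter_width)
            new_t = min(max(new_t, 0), n_bins - 1)
            jittered[new_t] += 1
    return jittered


counts = [0, 2, 1, 0, 3]
jittered = jitter_spike_counts(counts, 1, random.Random(0))
```

The total spike count is preserved; only the fine timing changes, which is what discourages the model from fitting it.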
Underfitting¶
If the primary task of LFADS is “filtering” of data and not generation, then it is possible that the KL penalty is too strong. Empirically, we have found this to be the case. So we add a hyperparameter in front of the two KL terms (one for the initial conditions to the generator, the other for the controller outputs). You should always think of the default values as 1.0, which leads to a standard VAE formulation whereby the quantity being optimized is a lower bound on the log-likelihood of the data. When these two hyperparameters deviate from 1.0, one can no longer make any statement about what those log-likelihood lower bounds mean, and they cannot be compared (AFAIK).
Sometimes the task can be sufficiently hard to learn that the optimizer takes the ‘easy route’ and simply minimizes the KL divergence, setting it to near zero, and the optimization gets stuck. The same possibility is true for the L2 regularizer. One wants a simple generator, for scientific reasons, but not at the expense of hosing the optimization. The start-step and increase-steps parameters below help avoid that by getting the optimization to ‘latch’ on to the main optimization, and only turning on the regularizers gradually by increasing their weighting in the overall cost function later.
Name | Default | Description |
---|---|---|
c_kl_ic_weight | 1 | Strength of KL weight on initial conditions KL penalty. |
c_kl_co_weight | 1 | Strength of KL weight on controller output KL penalty. |
c_kl_start_step | 0 | Start increasing KL weight after this many steps. |
c_kl_increase_steps | 900 | Number of steps over which the KL weight increases. |
c_l2_start_step | 0 | Start increasing L2 weight after this many steps. |
c_l2_increase_steps | 900 | Number of steps over which the L2 weight increases. |
scaleIncreaseStepsWithDatasets | true | If true, c_kl_increase_steps and c_l2_increase_steps will be multiplied by the number of datasets in a stitching run. |
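The gradual turn-on of the KL and L2 penalties can be pictured as a weight that ramps from 0 to its final value. The sketch below assumes a linear ramp and hypothetical function names; the actual shape of the schedule is determined by the LFADS Python code.

```python
def ramp_weight(step, start_step=0, increase_steps=900, final_weight=1.0):
    """Penalty weight at a given training step, assuming a linear ramp.

    The weight is 0 until start_step, rises linearly over increase_steps,
    and then stays at final_weight.
    """
    if increase_steps <= 0:
        return final_weight
    fraction = (step - start_step) / increase_steps
    return final_weight * min(max(fraction, 0.0), 1.0)


def effective_increase_steps(increase_steps, n_datasets, scale_with_datasets=True):
    """Apply scaleIncreaseStepsWithDatasets as described in the table."""
    return increase_steps * n_datasets if scale_with_datasets else increase_steps


halfway = ramp_weight(450)                      # half of the final weight
stitched_steps = effective_increase_steps(900, 3)  # ramp length for 3 datasets
```

With the defaults, the penalties would reach full strength after 900 steps for a single dataset and after 2700 steps for a three-dataset stitching run.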
External inputs¶
If there are observed inputs, there are two ways to add them to the model. The first is to treat the observed input as something to be inferred, encoding it via the encoders and then feeding it to the generator via the “inferred inputs” channel. The second is to feed the observed input directly into the generator. This has the downside of making the generation process strictly dependent on knowing the observed input for any generated trial.
Name | Default | Description |
---|---|---|
c_ext_input_dim | 0 | Number of external, known (or observed) inputs. |
c_inject_ext_input_to_gen | false | Should the known inputs be input to model via encoders (false) or injected directly into generator (true)? |
Controller and inferred inputs¶
The controller will be more powerful if it can see the encoding of the entire trial. However, this allows the controller to create inferred inputs that are acausal with respect to the actual data generation process. For example, the data generator could have an input at time `t`, but the controller, after seeing the entirety of the trial, could infer that the input is coming a little before time `t`, because there are no restrictions on the data the controller sees.
One can force the controller to be causal (with respect to perturbations in the data generator) so that it only sees forward encodings of the data at time `t` that originate at times before or at time `t`. One can also control the data the controller sees by using an input lag (forward encoding at time `t-tlag` for controller input at time `t`). The same can be done in the reverse direction (controller input at time `t` from reverse encoding at time `t+tlag`, in the case of an acausal controller). Setting this lag > 0 (even lag=1) can be a powerful way of avoiding very spiky decodes. Finally, one can manually control whether the factors at time `t-1` are fed to the controller at time `t`.
If you don’t care about any of this, and just want to smooth your data, set `do_causal_controller = False`, `do_feed_factors_to_controller = True`, `controller_input_lag = 0`.
Name | Default | Description |
---|---|---|
c_co_dim | 4 | Number of inferred inputs (controller outputs). This parameter critically controls whether or not there is a controller (along with controller encoders) placed into the LFADS graph. If equal to 0, no controller will be added. |
c_prior_ar_atau | 10 | Initial autocorrelation of AR(1) priors (in time bins) |
c_do_train_prior_ar_atau | true | Is the value for atau an initial value (true) or the constant value (false)? |
c_prior_ar_nvar | 0.1 | Initial noise variance for AR(1) priors |
c_do_train_prior_ar_nvar | true | Is the value for the noise var an initial value (true) or the constant value (false)? |
c_do_causal_controller | false | Restrict input encoder from seeing the future? |
c_do_feed_factors_to_controller | true | Should factors[t-1] be input to controller at time t? Strictly speaking, feeding either the factors or the rates to the controller violates causality, since the g0 gets to see all the data. This may or may not be only a theoretical concern. |
c_feedback_factors_or_rates | 'factors' | Feed back the factors or the rates to the controller? Set to either 'factors' or 'rates'. |
c_controller_input_lag | 1 | Time lag on the encoding fed to the controller: t-lag for the forward encoding, t+lag for the reverse encoding. |
c_ci_enc_dim | 128 | Network size for controller input encoder. |
c_con_dim | 128 | Controller dimensionality. |
c_co_prior_var_scale | 0.1 | Variance of control input prior distribution. |
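The indexing implied by the input lag and the causal option described in this section can be sketched with plain lists. This is an illustration of the text only: the function name, the use of a `pad` placeholder for time steps with no available encoding, and the tuple layout are all assumptions.

```python
def controller_inputs(fwd, rev, lag=1, causal=False, pad=None):
    """Pair each time step t with the encodings the controller would see.

    fwd[t] and rev[t] stand for the forward and reverse encodings at time t.
    The controller at time t receives fwd[t - lag] and, unless causal,
    rev[t + lag]. Out-of-range indices are replaced by pad.
    """
    n_steps = len(fwd)
    inputs = []
    for t in range(n_steps):
        forward = fwd[t - lag] if t - lag >= 0 else pad
        if causal:
            inputs.append((forward,))
        else:
            reverse = rev[t + lag] if t + lag < n_steps else pad
            inputs.append((forward, reverse))
    return inputs


fwd = ["f0", "f1", "f2"]
rev = ["r0", "r1", "r2"]
acausal = controller_inputs(fwd, rev, lag=1)
causal = controller_inputs(fwd, rev, lag=1, causal=True)
```

In the causal case the controller at time `t` only ever receives information from forward encodings at earlier times, which is the restriction `c_do_causal_controller` is meant to enforce.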
Encoder and initial conditions for generator¶
Note that the dimension of the initial conditions (`c_ic_dim`) is separate from the dimension of the generator’s initial state (a linear matrix will adapt the shapes if necessary). This is just another way to control complexity. In all likelihood, setting the IC dims to the size of the generator hidden state is just fine.
For the initial condition prior variance parameters, it’s best to leave them alone. The defaults should be fine for most cases, regardless of other parameters. If you don’t want the prior variance to be learned, set the following values to the same thing: `ic_prior_var_min`, `ic_prior_var_scale`, `ic_prior_var_max`. The prior mean will still be learned. If you really want to limit the information from encoder to decoder, increase `ic_post_var_min` above 0.
Name | Default | Description |
---|---|---|
c_num_steps_for_gen_ic | MAXINT | Number of steps to train the generator initial condition. |
c_ic_dim | 64 | Dimensionality of the initial conditions. |
c_ic_enc_dim | 128 | Network size for IC encoder. |
c_ic_prior_var_min | 0.1 | Minimum variance of IC prior distribution |
c_ic_prior_var_scale | 0.1 | Variance of IC prior distribution |
c_ic_prior_var_max | 0.1 | Maximum variance of IC prior distribution |
c_ic_post_var_min | 0.0001 | Minimum variance of IC posterior distribution |
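One way to see why setting the minimum, scale, and maximum to the same value fixes the prior variance is to think of the learned variance as being kept within `[min, max]`. The sketch below assumes simple clipping and a hypothetical function name; it is not taken from the LFADS code.

```python
def constrain_prior_var(proposed_var, var_min=0.1, var_max=0.1):
    """Keep a learned variance inside [var_min, var_max] by clipping.

    Assumption for this sketch: the optimizer proposes a variance and the
    bounds are enforced afterwards. When var_min == var_max, every proposal
    maps to the same value, so the variance is effectively not learned.
    """
    return min(max(proposed_var, var_min), var_max)


# With the defaults (min == max == 0.1) the variance cannot move.
fixed = [constrain_prior_var(v) for v in (0.01, 0.1, 5.0)]

# Widening the bounds lets the variance be learned within that range.
learned = [constrain_prior_var(v, var_min=0.05, var_max=1.0) for v in (0.01, 0.3, 5.0)]
```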
Generator network, factors, rates¶
Controlling the size of the generator is one way to control the complexity of the dynamics (there is also L2 regularization, which will likewise squeeze out unnecessary dynamics). The modern deep learning approach is to make these cells as large as tolerable (from a waiting perspective), and then regularize them to death with dropout or whatever. It is not clear whether this is correct for the LFADS application.
Name | Default | Description |
---|---|---|
c_cell_weight_scale | 1.0 | Input scaling for input weights in generator. The combined recurrent and input weights of the encoder and controller cells are by default set to scale at ws/sqrt(#inputs) with ws=1.0. You can change this scaling with this parameter. |
c_gen_dim | 100 | Generator network size |
c_gen_cell_input_weight_scale | 1.0 | Input scaling for input weights in generator, which will be divided by sqrt(#inputs) |
c_gen_cell_rec_weight_scale | 1.0 | Input scaling for recurrent weights in generator. |
c_factors_dim | 50 | Dimensionality of factors read out from generator network. This provides dimensionality reduction from generator dimensionality down to factors and then back out to the neural rates. Note that this property does affect the data and param hashes, unlike the other c_ prefixed parameters, which only affect the param hash. |
c_output_dist | 'poisson' | Type of output distribution for rates, either 'poisson' or 'gaussian' |
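The `ws/sqrt(#inputs)` scaling mentioned for the weight-scale parameters can be illustrated with a small initializer. This is a sketch under stated assumptions: the function names are hypothetical, and drawing weights from a zero-mean Gaussian is an assumption made here for illustration.

```python
import math
import random


def init_std(n_inputs, weight_scale=1.0):
    """Standard deviation implied by the ws / sqrt(#inputs) rule."""
    return weight_scale / math.sqrt(n_inputs)


def init_weights(n_inputs, n_units, weight_scale=1.0, seed=0):
    """Draw an n_inputs x n_units weight matrix (Gaussian is an assumption)."""
    rng = random.Random(seed)
    std = init_std(n_inputs, weight_scale)
    return [[rng.gauss(0.0, std) for _ in range(n_units)] for _ in range(n_inputs)]


# With 100 inputs and ws = 1.0, weights have standard deviation 0.1.
std_default = init_std(100)
weights = init_weights(4, 3, weight_scale=2.0)
```

Dividing by the square root of the number of inputs keeps the summed input to each unit at a similar magnitude as the layer grows; the weight-scale parameters then rescale that magnitude.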
Stitching multi-session models¶
Name | Default | Description |
---|---|---|
c_do_train_readin | true | For stitching models, make the readin matrices trainable (true) or fix them to equal the alignment matrices (false). The per-session readin matrices map from neurons to input factors which are fed into the shared encoder. These are initialized by the alignment matrices and can subsequently be fixed or made trainable. |
useAlignmentMatrix | false | Whether to use an alignment matrix when stitching datasets together. |
useSingleDatasetAlignmentMatrix | false | When only using a single dataset, it is also possible to use a readin matrix that reduces the dimensionality of the spikes before inputting these input factors to the encoder networks. If set true, this will set up this readin matrix and seed it with an alignment matrix computed using PCA. |
alignmentApproach | 'regressGlobalPCs' | Algorithm to use when calculating the initial alignment (readin) matrices that map from each session-specific set of neurons to common input factors. Default 'regressGlobalPCs' computes the PCs of all neurons together, and then regresses those PCs on the neurons from each session separately. This mapping produces the best possible linear reconstruction of the global PCs from each session. 'ridgeRegressGlobalPCs' does the same but uses ridge regression (L2 regularization) for more robustness, using a lambda value optimized via cross-validation. This approach may yield better results if some neurons are quite variable or exhibit sparse firing. Additional modes can be implemented; see LFADS.Run/prepareAlignmentMatrices which you may override in your derived `Run` class. |
alignmentExtraArgs | {} | Extra arguments to be passed by LFADS.Run/prepareAlignmentMatrices , either to LFADS.MultisessionAlignmentTool.computeAlignmentMatricesUsingTrialAveragedPCR in the default implementation, or as additional parameters to a custom alignment algorithm. |
Posterior sampling¶
Name | Default | Description |
---|---|---|
posterior_mean_kind | 'posterior_sample_and_average' | Mechanism to obtain the posterior mean. Use 'posterior_sample_and_average' to take a specified number of samples from the posterior distribution, run them through the model, and average the results, or 'posterior_push_mean' to use the posterior mean of the ICs and inputs and push those through the model directly. Since there are nonlinearities in the network, pushing the mean through need not be equivalent to the mean of the samples, but in practice it's usually pretty close, and it is much faster to compute. Note that this parameter does not affect either the param or data hash. |
num_samples_posterior | 512 | Number of samples of the posterior to use when using 'posterior_sample_and_average' . Note that this parameter does not affect either the param or data hash. |
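The difference between the two `posterior_mean_kind` options comes from the nonlinearities in the network. The toy sketch below replaces the whole model with a single exponential nonlinearity and a one-dimensional Gaussian posterior; all names are hypothetical and nothing here calls LFADS itself.

```python
import math
import random


def toy_model(z):
    """Stand-in for the network: a single exponential nonlinearity."""
    return math.exp(z)


def sample_and_average(mean, std, num_samples=512, seed=0):
    """Draw samples from the posterior, run each through the model, average."""
    rng = random.Random(seed)
    outputs = [toy_model(rng.gauss(mean, std)) for _ in range(num_samples)]
    return sum(outputs) / num_samples


def push_mean(mean):
    """Run only the posterior mean through the model (one forward pass)."""
    return toy_model(mean)


averaged = sample_and_average(mean=0.0, std=0.5)
pushed = push_mean(0.0)
```

In this toy case the averaged output exceeds the pushed-mean output, because averaging after a convex nonlinearity is not the same as applying the nonlinearity to the average. The pushed-mean version needs one model evaluation instead of `num_samples_posterior` of them, which is where the speed advantage comes from.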