Multi-dataset Stitching Models¶
If you specify multiple datasets to be included in an LFADS run by selecting multiple datasets in a `RunSpec`, the resulting model will stitch together the multiple datasets. The concept is to generate the spiking data in all of the included datasets using the same encoder and generator RNNs, but to interface with the separate neural datasets through read-in and readout alignment matrices.
Below is a schematic of the readout side. Here, the generator RNN and the readout from generator units to factors are shared across all datasets. The intent is that the factor trajectories will be similar for similar trials and conditions across datasets. Going from factors to rates, however, the recorded neurons are generally not the same across datasets, and their number may differ. Dataset-specific readout matrices are therefore used to combine the factors into the rates of each recorded neuron in each dataset.
A similar set of dataset-specific read-in matrices connects the spiking data to the encoder RNN in order to produce the initial conditions and inferred inputs for each trial.
Generating alignment matrices¶
These read-in and readout alignment matrices are learned from the data along with the other parameters. However, it's useful to seed the alignment matrices with an initial guess that suggests the correspondence between the datasets. If you have multiple datasets in a `RunSpec`, and the hyperparameter `useAlignmentMatrix` is set to `true` in the `RunParams`, then `lfads-run-manager` will automatically generate read-in alignment matrices from your data using a principal components regression algorithm, which proceeds as follows (a rough code sketch appears after the list):
- Generate condition-averaged firing rates for each neuron, for each condition, for each dataset.
- Concatenate the neurons from all datasets to build a matrix of size (`nTime * nConditions`) x `nNeuronsTotal`.
- Perform PCA on this matrix and keep the projections of the data onto the top `nFactors` components. These represent the global shared structure of the data across all datasets.
- For each dataset individually, regress these projection scores onto the condition-averaged rates from that dataset alone. The regression coefficients transform that dataset's neurons into the global shared structure, so we take this matrix of regression coefficients as that dataset's read-in matrix.
These matrices will be computed for you automatically by `run.prepareForLFADS()` and exported in the LFADS input folder. LFADS will generate an initial guess for the readout alignment matrix, which transforms from the common factors back to dataset-specific rates, using the pseudo-inverse of the read-in alignment matrix computed by `lfads-run-manager`.
Alignment biases

In addition to the read-in alignment matrix, there is also an alignment bias vector that is added to each neuron's counts before projecting through the matrix. Accordingly, `lfads-run-manager` seeds this bias with the negative of each neuron's mean rate, which roughly centers the counts before they pass through the read-in matrix.
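Continuing the hypothetical sketch above, the readout and bias seeds could be formed along these lines (again, this is not the library's exact code; `lfads-run-manager` handles it for you):

```matlab
% Continuing the illustrative sketch above (not the library's exact code)
readoutSeed = cell(numel(condAvg), 1);
biasSeed = cell(numel(condAvg), 1);
for iDS = 1:numel(condAvg)
    % readout seed: pseudo-inverse of the read-in seed (factors -> rates)
    readoutSeed{iDS} = pinv(readinSeed{iDS});   % nFactors x nNeurons(iDS)
    % bias seed: negative mean rate of each neuron, roughly centering the
    % counts before they are projected through the read-in matrix
    biasSeed{iDS} = -mean(condAvg{iDS}, 1)';    % nNeurons(iDS) x 1
end
```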
Setting up a multi-session LFADS run¶
Assuming you have finished reading through the single-dataset LFADS walkthrough, you should be all set to generate some LFADS runs and start training. We'll set up another drive script that will do the work of creating the appropriate instances, pointing at the datasets, creating the runs, and telling LFADS Run Manager to generate the files needed for LFADS. Below, we'll refer to the package name as `LorenzExperiment`, but you should substitute your own package name.
Follow along with `LorenzExperiment.drive_script`

A complete drive script is available as a starting point in `+LorenzExperiment/drive_script.m` for you to copy/paste from.
For this demo, as before, we'll generate a few datasets of synthetic spiking data driven by a Lorenz attractor using the following code:
```matlab
datasetPath = '~/lorenz_example/datasets';
LFADS.Utils.generateDemoDatasets(datasetPath, 'nDatasets', 3);
```
This will simulate a chaotic 3-dimensional Lorenz attractor as the underlying dynamical system, initialized from 65 initial conditions. The key feature of these demonstration datasets is that each of the 65 conditions starts from the same initial state and evolves identically across all 3 datasets. Each dataset, however, contains a disjoint set of neurons, each of which is a different linear combination of the 3 dimensions of the Lorenz attractor state. This mirrors the assumption underlying LFADS stitching: each dataset contains a different set of neurons, all reconstructed from a shared low-dimensional set of factors.
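If you'd like to confirm what was generated before continuing, you can list the files and peek at the variables stored in one of them; the exact variable names depend on the demo generator:

```matlab
% list the generated demo dataset files
dir(fullfile(datasetPath, '*.mat'))

% inspect what is stored in the first dataset without loading it into memory
whos('-file', fullfile(datasetPath, 'dataset001.mat'))
```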
Building a dataset collection and adding datasets¶
First, create a dataset collection that points to a folder on disk where datasets are stored:
```matlab
dataPath = '~/lorenz_example/datasets';
dc = LorenzExperiment.DatasetCollection(dataPath);
dc.name = 'lorenz_example';
```
Then we can add the individual datasets by pointing at their individual files. Note that when a new dataset instance is created, it is automatically added to the `DatasetCollection`, replacing any existing dataset with the same name.
```matlab
LorenzExperiment.Dataset(dc, 'dataset001.mat');
LorenzExperiment.Dataset(dc, 'dataset002.mat');
LorenzExperiment.Dataset(dc, 'dataset003.mat');
```
You can verify that the datasets have been added to the collection:
```
>> dc

LorenzExperiment.DatasetCollection "lorenz_example"
  3 datasets in ~/lorenz_example/datasets
  [ 1] LorenzExperiment.Dataset "dataset001"
  [ 2] LorenzExperiment.Dataset "dataset002"
  [ 3] LorenzExperiment.Dataset "dataset003"

          name: 'lorenz_example'
       comment: ''
          path: '~/lorenz_example/datasets'
      datasets: [3x1 LorenzExperiment.Dataset]
     nDatasets: 3
  datasetNames: {3x1 cell}
```
You can access individual datasets using `dc.datasets(1)` or by name with `dc.matchDatasetsByName('dataset001')`.
You can then load all of the metadata for the datasets using:
```matlab
dc.loadInfo();
```
How this metadata is determined for each dataset may be customized as described in Interfacing with your Datasets. You can view a summary of the metadata using:
```
>> dc.getDatasetInfoTable

                   subject                 date            saveTags    nTrials    nChannels
              ________________    ______________________   ________    _______    _________

dataset001    'lorenz_example'    [31-Jan-2018 00:00:00]     '1'        1820         35
dataset002    'lorenz_example'    [31-Jan-2018 00:00:00]     '1'        1885         26
dataset003    'lorenz_example'    [31-Jan-2018 00:00:00]     '1'        1365         35
```
Create a RunCollection¶
We'll now set up a `RunCollection` that will contain all of the LFADS runs we'll be training. All of the processed data and LFADS output will be stored inside this folder, neatly organized into subfolders.
```matlab
runRoot = '~/lorenz_example/runs';
rc = LorenzExperiment.RunCollection(runRoot, 'exampleStitching', dc);

% replace with approximate date script authored as YYYYMMDD
% to ensure forwards compatibility
rc.version = 20180131;
```
Specify the hyperparameters in RunParams¶
We'll next specify a single set of hyperparameters to begin with. Since this is a simple dataset, we'll reduce the size of the generator network to 64 units and reduce the number of factors to 8. The key change is to set `useAlignmentMatrix` to `true` in order to seed the read-in matrices.
```matlab
par = LorenzExperiment.RunParams;
par.name = 'first_attempt_stitching'; % completely optional
par.useAlignmentMatrix = true; % use alignment matrices initial guess for multisession stitching
par.spikeBinMs = 2; % rebin the data at 2 ms
par.c_co_dim = 0; % no controller --> no inputs to generator
par.c_batch_size = 150; % must be < 1/5 of the min trial count
par.c_factors_dim = 8; % set the factors dimension manually for multisession stitched models
par.c_gen_dim = 64; % number of units in generator RNN
par.c_ic_enc_dim = 64; % number of units in encoder RNN
par.c_learning_rate_stop = 1e-3; % we can stop really early for the demo
```
We then add this `RunParams` to the `RunCollection`:
```matlab
rc.addParams(par);
```
You can access the parameter settings added to `rc` using `rc.params`, which will be an array of `RunParams` instances.
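If you later want to compare hyperparameter settings, you can add further `RunParams` instances in the same way; note that every parameter setting is crossed with every run specification, so each addition multiplies the total number of runs. A purely hypothetical example:

```matlab
% hypothetical second parameter setting with a larger factor dimensionality
par2 = LorenzExperiment.RunParams;
par2.name = 'more_factors';        % optional
par2.useAlignmentMatrix = true;
par2.spikeBinMs = 2;
par2.c_co_dim = 0;
par2.c_batch_size = 150;
par2.c_factors_dim = 12;           % try more factors than the first setting
par2.c_gen_dim = 64;
par2.c_ic_enc_dim = 64;
par2.c_learning_rate_stop = 1e-3;
rc.addParams(par2);
```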
Specify the RunSpec set¶
Recall that `RunSpec` instances specify which datasets are included in a specific run. For stitching, we'll want to include all three datasets in a single model.
```matlab
% include all datasets
runSpecName = 'all';
runSpec = LorenzExperiment.RunSpec(runSpecName, dc, 1:dc.nDatasets);
rc.addRunSpec(runSpec);
```
You can adjust the arguments to the constructor of `LorenzExperiment.RunSpec`, but in the example provided the inputs define:

- the unique name of the run. Here we simply use `'all'`; for the optional single-dataset runs below, we'll instead use `getSingleRunName`, a convenience method of `Dataset` that generates a name like `single_datasetName`.
- the `DatasetCollection` from which datasets will be retrieved
- the indices or names of the datasets to include (as a string or cell array of strings)
If you like, you can also add `RunSpec`s that train individual models on each dataset, to facilitate comparison with the stitched model.
```matlab
% add one run for each single dataset
for iR = 1:dc.nDatasets
    runSpecName = dc.datasets(iR).getSingleRunName(); % 'single_dataset###'
    runSpec = LorenzExperiment.RunSpec(runSpecName, dc, iR);
    rc.addRunSpec(runSpec);
end
```
Check the RunCollection and the Run¶
The `RunCollection` will now display information about the parameter settings and run specifications that have been added. Here there is only one parameter setting and one run specification, so we're performing only 1 run in total.
```
>> rc

LorenzExperiment.RunCollection "exampleStitching" (1 runs total)
  Dataset Collection "lorenz_example" (3 datasets) in ~/lorenz_example/datasets
  Path: ~/lorenz_example/runs/exampleStitching

  1 parameter settings
  [1 param_Qr2PeG data_RE1kuL] LorenzExperiment.RunParams "first_attempt_stitching" useAlignmentMatrix=true c_factors_dim=8 c_ic_enc_dim=64 c_gen_dim=64 c_co_dim=0 c_batch_size=150 c_learning_rate_stop=0.001

  1 run specifications
  [ 1] LorenzExperiment.RunSpec "all" (3 datasets)

                        name: 'exampleStitching'
                     comment: ''
                    rootPath: '~/lorenz_example/runs'
                     version: 201801
           datasetCollection: [1x1 LorenzExperiment.DatasetCollection]
                        runs: [1x1 LorenzExperiment.Run]
                      params: [1x1 LorenzExperiment.RunParams]
                    runSpecs: [1x1 LorenzExperiment.RunSpec]
                     nParams: 1
                   nRunSpecs: 1
                  nRunsTotal: 1
                   nDatasets: 3
                datasetNames: {3x1 cell}
                        path: '~/lorenz_example/runs/exampleStitching'
    pathsCommonDataForParams: {'~/lorenz_example/runs/exampleStitching/data_RE1kuL'}
              pathsForParams: {'~/lorenz_example/runs/exampleStitching/param_Qr2PeG'}
  fileShellScriptTensorboard: '~/lorenz_example/runs/exampleStitching/launch_tensorboard.sh'
             fileSummaryText: '~/lorenz_example/runs/exampleStitching/summary.txt'
     fileShellScriptRunQueue: '~/lorenz_example/runs/exampleStitching/run_lfadsqueue.py'

>> run = rc.findRuns('all', 1)

LorenzExperiment.Run "all" (3 datasets)
  Path: ~/lorenz_example/runs/exampleStitching/param_Qr2PeG/all
  Data: ~/lorenz_example/runs/exampleStitching/data_RE1kuL
  LorenzExperiment.RunParams "first_attempt_stitching" : useAlignmentMatrix=true c_factors_dim=8 c_ic_enc_dim=64 c_gen_dim=64 c_co_dim=0 c_batch_size=150 c_learning_rate_stop=0.001
  3 datasets in "lorenz_example"
  [ 1] LorenzExperiment.Dataset "dataset001"
  [ 2] LorenzExperiment.Dataset "dataset002"
  [ 3] LorenzExperiment.Dataset "dataset003"
  ...
```
Verifying the alignment matrices¶
Next, we'll run the principal components regression that generates the alignment matrices using the algorithm described above. Then we'll verify that these matrices are able to project the data from each dataset into similar-looking low-dimensional trajectories.
To visualize how well these initial alignment matrices are working, we can compare the common global PCs from all datasets against the projection of each dataset through the read-in matrices. That is, we can plot the regression target (global PCs) against the best possible reconstruction from each dataset.
```matlab
run.doMultisessionAlignment();
```
Under the hood, the alignment matrix calculations are performed by an instance of `LFADS.MultisessionAlignmentTool`. To plot the reconstruction quality, you can call `tool.plotAlignmentReconstruction(numberOrIndicesOfFactorsToPlot, numberOrIndicesOfConditionsToPlot)`, like so:
```matlab
tool = run.multisessionAlignmentTool;
nFactorsPlot = 3;
conditionsToPlot = [1 20 40];
tool.plotAlignmentReconstruction(nFactorsPlot, conditionsToPlot);
```
In this example, the single-dataset predictions look quite similar to the global target, especially in the first 2 principal components, which capture most of the variance.
The actual alignment matrices can be accessed using:
```matlab
tool.alignmentMatrices % nDatasets x 1 cell array of read-in matrices
```
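As a quick sanity check, you can inspect their dimensions; assuming each read-in matrix maps a dataset's neurons onto the shared factors, its size should be that dataset's channel count by `c_factors_dim`:

```matlab
% expected (assumed) size: nChannels for that dataset x par.c_factors_dim
for iDS = 1:dc.nDatasets
    fprintf('dataset %d read-in matrix: %d x %d\n', iDS, size(tool.alignmentMatrices{iDS}));
end
```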
Prepare for LFADS¶
Now that you’ve set up your run collection with all of your runs, you can run the following to generate the files needed for running LFADS.
```matlab
rc.prepareForLFADS();
```
This will generate files for all runs. If you later decide to add new runs by adding additional run specifications or parameters, you can simply call `prepareForLFADS` again. Existing files won't be overwritten unless you call `rc.prepareForLFADS(true)`.
After running `prepareForLFADS`, the run manager will have created the following files on disk under `rc.path`:
```
~/lorenz_example/runs/exampleStitching
├── data_4MaTKO
│   ├── inputInfo_dataset001.mat
│   ├── inputInfo_dataset002.mat
│   ├── inputInfo_dataset003.mat
│   ├── lfads_dataset001.h5
│   ├── lfads_dataset002.h5
│   └── lfads_dataset003.h5
├── param_YOs74u
│   └── all
│       └── lfadsInput
│           ├── inputInfo_dataset001.mat -> ../../../data_4MaTKO/inputInfo_dataset001.mat
│           ├── inputInfo_dataset002.mat -> ../../../data_4MaTKO/inputInfo_dataset002.mat
│           ├── inputInfo_dataset003.mat -> ../../../data_4MaTKO/inputInfo_dataset003.mat
│           ├── lfads_dataset001.h5 -> ../../../data_4MaTKO/lfads_dataset001.h5
│           ├── lfads_dataset002.h5 -> ../../../data_4MaTKO/lfads_dataset002.h5
│           └── lfads_dataset003.h5 -> ../../../data_4MaTKO/lfads_dataset003.h5
└── summary.txt
```
The organization of these files on disk is discussed in more detail here. A `summary.txt` file will also be generated, which can be useful for identifying all of the runs and their locations on disk. You can also generate this text from within Matlab by calling `rc.generateSummaryText()`.
```
LorenzExperiment.RunCollection "exampleStitching2" (1 runs total)
  Path: ~/lorenz_example/runs/exampleStitching2
  Dataset Collection "lorenz_example" (3 datasets) in ~/lorenz_example/datasets

------------------------
1 Run Specifications:

[runSpec 1] LorenzExperiment.RunSpec "all" (3 datasets)
    [ds 1] LorenzExperiment.Dataset "dataset001"
    [ds 2] LorenzExperiment.Dataset "dataset002"
    [ds 3] LorenzExperiment.Dataset "dataset003"

------------------------
1 Parameter Settings:

[1 param_Qr2PeG data_RE1kuL] LorenzExperiment.RunParams "first_attempt_stitching" useAlignmentMatrix=true c_factors_dim=8 c_ic_enc_dim=64 c_gen_dim=64 c_co_dim=0 c_batch_size=150 c_learning_rate_stop=0.001
                          spikeBinMs: 2
                    trainToTestRatio: 4
                  useAlignmentMatrix: true
     useSingleDatasetAlignmentMatrix: false
      scaleIncreaseStepsWithDatasets: true
                   c_cell_clip_value: 5
                       c_factors_dim: 8
                        c_ic_enc_dim: 64
                        c_ci_enc_dim: 128
                           c_gen_dim: 64
                         c_keep_prob: 0.95
        c_learning_rate_decay_factor: 0.98
                            c_device: /gpu:0
                            c_co_dim: 0
              c_do_causal_controller: false
     c_do_feed_factors_to_controller: true
         c_feedback_factors_or_rates: factors
              c_controller_input_lag: 1
                   c_do_train_readin: true
                      c_l2_gen_scale: 500
                      c_l2_con_scale: 500
                        c_batch_size: 150
                 c_kl_increase_steps: 900
                 c_l2_increase_steps: 900
                            c_ic_dim: 64
                           c_con_dim: 128
                c_learning_rate_stop: 0.001
       c_temporal_spike_jitter_width: 0
                  c_allow_gpu_growth: true
                      c_kl_ic_weight: 1
                      c_kl_co_weight: 1
           c_inject_ext_input_to_gen: false
                     c_prior_ar_atau: 10
            c_do_train_prior_ar_atau: true
                     c_prior_ar_nvar: 0.1
            c_do_train_prior_ar_nvar: true
               num_samples_posterior: 512
                 posterior_mean_kind: posterior_sample_and_average
```
After running `prepareForLFADS`, you can then run the LFADS model or models in the same way as with single-session models, using the instructions here.