topomics.models.MultimodalAmortizedLDA
- class topomics.models.MultimodalAmortizedLDA(adata, n_inputs_modalities, likelihoods, n_topics=20, n_hidden=128, cell_topic_prior=None, topic_feature_prior=None, topic_feature_prior_type='logistic_normal', use_feature_background=True, dispersion_rna=1.0, learnable_dispersion=False, global_dispersion=True, modality_names=None, weight_mode='cell', likelihood_weight_mode='none', likelihood_weight_ref='mean', gcn_n_layers=1, gcn_n_pre_layers=0, gcn_conv_type='GATv2Conv', gcn_hidden_dims=None, gcn_alpha_init=0.7, gcn_use_learned_alpha=True, normalize_encoder_inputs=True, encoder_scale_factor=1000000.0, entropy_weight=0.01, topic_variance_weight=1.0, kl_weight=1.0, encode_covariates=True, bg_offset=1e-15, learnable_bg=True, aggregation_type='moe', att_dim=32, spatial_mode='gcn', sgc_n_layers=1, gcn_sampling='approx', gcn_fan_out=None, gcn_conv_first=False)
Multimodal Amortized LDA with Mixture-of-Experts (MoE).
Extends scvi.model.AmortizedLDA to M modalities with modality-specific encoders and likelihoods. Each modality is encoded separately. The resulting representations are then mixed via a weighted Gaussian combination before the shared cell-topic distribution θₙ is inferred.
- Parameters:
adata (AnnData) – AnnData with concatenated features (RNA + protein + …).
n_inputs_modalities (list[int]) – Number of features per modality, in the order the modalities appear in adata.X.
likelihoods (list[str]) – Likelihood name for each modality; must have the same length as n_inputs_modalities.
n_topics (int (default: 20)) – Number of topics (K).
n_hidden (int (default: 128)) – Number of hidden units in each encoder network.
cell_topic_prior (float | Sequence[float] | None (default: None)) – Dirichlet concentration for θₙ. None ⇒ symmetric 1/K.
topic_feature_prior (float | Sequence[float] | None (default: None)) – Dirichlet concentration for each ϕₖ,ₘ. None ⇒ symmetric 1/K.
weight_mode (str (default: 'cell')) – How to weight the modality-specific representations:
  - "equal": all modalities are weighted equally
  - "universal": learn a single weight per modality
  - "cell": learn per-cell, per-modality weights (default)
Notes
The Mixture-of-Experts architecture processes each modality through a separate encoder network, then combines their latent representations using learned or fixed weights.
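The description above says only that modality representations are mixed "via weighted Gaussian combination". The sketch below assumes each encoder outputs a diagonal Gaussian N(μₘ, σₘ²) and that the combination is a weighted sum of independent Gaussian variables. Under that assumption, the three weight_mode options differ only in the shape of the weight array. topomics may combine only the means, or use a different rule per aggregation_type. combine_experts is a hypothetical helper, not part of the package.

```python
import numpy as np


def combine_experts(mus, sigmas, weights):
    """Combine per-modality Gaussians N(mu_m, sigma_m^2) with mixing weights w_m.

    Hypothetical sketch, not the topomics implementation.
    mus, sigmas: arrays of shape (M, n_cells, d), one Gaussian per modality.
    weights: shape (M,) for one weight per modality ("equal"/"universal"),
             or (M, n_cells) for per-cell weights ("cell").
    Returns mean and std of z = sum_m w_m * z_m, treating the z_m as independent,
    after normalizing the weights to sum to 1 over modalities.
    """
    mus = np.asarray(mus, dtype=float)
    sigmas = np.asarray(sigmas, dtype=float)
    w = np.asarray(weights, dtype=float)
    w = w / w.sum(axis=0, keepdims=True)                 # normalize over modalities
    w = w.reshape(w.shape + (1,) * (mus.ndim - w.ndim))  # broadcast against (M, n_cells, d)
    mean = (w * mus).sum(axis=0)
    std = np.sqrt((w**2 * sigmas**2).sum(axis=0))
    return mean, std


# Two modalities, three cells, a 4-dimensional latent space.
mus = np.stack([np.zeros((3, 4)), np.ones((3, 4))])
sigmas = np.ones((2, 3, 4))

# "equal"-style weights: one weight per modality.
mean_equal, std_equal = combine_experts(mus, sigmas, np.ones(2))
# "cell"-style weights: each column is one cell's weights over modalities.
mean_cell, _ = combine_experts(mus, sigmas, [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]])
```

With per-cell weights, cell 0 follows modality 0 alone and cell 1 follows modality 1 alone. Cell 2 averages the two.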
Attributes table
- adata – Data attached to model instance.
- adata_manager – Manager instance associated with self.adata.
- device – The current device that the module's params are on.
- get_normalized_function_name – Name of the model's get-normalized function.
- history – Returns computed metrics during training.
- is_trained – Whether the model has been trained.
- registry – Registry of the AnnDataManager associated with the model instance.
- run_id – Returns the run id of the model.
- run_name – Returns the run name of the model.
- summary_string – Summary string of the model.
- test_indices – Observations that are in test set.
- train_indices – Observations that are in train set.
- validation_indices – Observations that are in validation set.
Methods table
- check_input – Validate and process the input data.
- check_modalities_names – Standardize and validate modality keys in data_dict.
- clear_metric_cache – Clear the cached metrics.
- convert_legacy_save – Converts a legacy saved model (<v0.15.0) to the updated save format.
- cross_modality_score – Compute a SHARE-Topic-style cross-modal interaction matrix P_{a,b}.
- data_registry – Returns the object in AnnData associated with the key in the data registry.
- deregister_manager – Deregisters the AnnDataManager instance associated with adata.
- differential_abundance – Not implemented for this model class.
- fit – Fit the model to the provided data.
- from_data – Convenience constructor: setup + instantiation in one call.
- from_mudata – High-level constructor for multimodal AmortizedLDA from MuData.
- get_anndata_manager – Retrieves the AnnDataManager for a given AnnData object.
- get_cell_entropy – Compute per-cell entropy of cell-topic distributions.
- get_cell_topic_dist – Get the cell-topic matrix Θ (C × K).
- get_elbo – Average ELBO across batches (higher is better).
- get_entropy – Compute mean entropy of cell-topic distributions.
- get_entropy_weight – Get the entropy regularization weight.
- get_feature_topic_dist – Monte Carlo estimate of E[ϕₖ,ₘ].
- get_from_registry – Returns the object in AnnData associated with the key in the data registry.
- get_last_entropy – Get the mean entropy from the last forward pass through the model.
- get_last_topic_variance – Get the last computed mean topic variance from training.
- get_latent_representation – Infer θₙ for all cells (or a subset).
- get_learned_dispersion – Get the learned or fixed dispersion parameters.
- get_likelihood_per_modality – Compute the log-likelihood for each modality separately.
- get_likelihood_weights – Return per-modality likelihood scaling weights used in the generative model.
- get_modality_weights – Get normalized mixing weights showing how much each modality contributes to topic assignments.
- get_normalized_expression – Not implemented for this model class.
- get_perplexity – exp(-log_likelihood / total_counts); lower is better.
- get_perplexity_per_modality – Compute perplexity for each modality separately.
- get_setup_arg – Returns the string provided to setup of a specific setup_arg.
- get_state_registry – Returns the state registry for the AnnDataField registered with this instance.
- get_top_features_per_topic – Get the top N features for each topic in a specific modality.
- get_topic_diversity – Compute topic diversity (average pairwise cosine distance) per modality or overall.
- get_topic_variance – Compute per-topic variance of topic usage across cells.
- get_topic_variance_weight – Get the topic variance regularization weight.
- get_var_names – Variable names of input data.
- load – Instantiate a model from the saved output.
- load_registry – Return the full registry saved with the model.
- predict – Predict using the fitted model on the provided data.
- register_manager – Registers an AnnDataManager instance with this model class.
- save – Save the state of the model.
- setup_adata_dict – Setup method for dict[str, AnnData] input.
- setup_anndata – Sets up the AnnData object for this model.
- setup_data – Universal setup method with automatic type detection.
- setup_mudata – Setup MuData for multimodal AmortizedLDA.
- setup_spatialdata – Setup method for SpatialData input.
- Move the model to the device.
- Override to default to running validation when a split is requested.
- Transfer fields from a model to an AnnData object.
- Update setup method args.
- Print summary of the setup for the initial AnnData or a given AnnData object.
- Prints summary of the registry.
- Print args used to setup a saved model.
- Prints setup kwargs used to produce a given registry.
Attributes
- MultimodalAmortizedLDA.adata
Data attached to model instance.
- MultimodalAmortizedLDA.adata_manager
Manager instance associated with self.adata.
- MultimodalAmortizedLDA.device
The current device that the module's params are on.
- MultimodalAmortizedLDA.get_normalized_function_name
Name of the model's get-normalized function.
- MultimodalAmortizedLDA.history
Returns computed metrics during training.
- MultimodalAmortizedLDA.is_trained
Whether the model has been trained.
- MultimodalAmortizedLDA.registry
Registry of the AnnDataManager associated with the model instance.
- MultimodalAmortizedLDA.run_id
Returns the run id of the model. Used by MLflow.
- MultimodalAmortizedLDA.run_name
Returns the run name of the model. Used by MLflow.
- MultimodalAmortizedLDA.summary_string
Summary string of the model.
- MultimodalAmortizedLDA.test_indices
Observations that are in test set.
- MultimodalAmortizedLDA.train_indices
Observations that are in train set.
- MultimodalAmortizedLDA.validation_indices
Observations that are in validation set.
Methods
- MultimodalAmortizedLDA.check_input(mdata, modalities)
Validate and process the input data.
Checks that the data are AnnData or MuData objects and that the modalities are correctly specified.
- MultimodalAmortizedLDA.check_modalities_names()
Standardize and validate modality keys in data_dict.
Maps various synonyms to 'rna', 'protein', or 'chromatin', and rebuilds data_dict with standardized keys.
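The standardize-and-rebuild pattern can be sketched as follows. The synonym table is hypothetical, because the synonyms actually accepted are not listed here. Rejecting unknown names and collisions is also an assumption about what "validate" means.

```python
# Hypothetical synonym table; the real list used by check_modalities_names is not documented here.
MODALITY_SYNONYMS = {
    "rna": "rna", "gex": "rna", "gene_expression": "rna",
    "protein": "protein", "adt": "protein", "prot": "protein",
    "chromatin": "chromatin", "atac": "chromatin", "peaks": "chromatin",
}


def standardize_modality_keys(data_dict):
    """Return a new dict whose keys are mapped to 'rna', 'protein' or 'chromatin'."""
    standardized = {}
    for key, value in data_dict.items():
        canonical = MODALITY_SYNONYMS.get(key.strip().lower())
        if canonical is None:
            raise ValueError(f"Unrecognized modality name: {key!r}")
        if canonical in standardized:
            # Two input keys (e.g. "rna" and "gex") would collapse onto one modality.
            raise ValueError(f"Multiple keys map to modality {canonical!r}")
        standardized[canonical] = value
    return standardized
```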
- MultimodalAmortizedLDA.clear_metric_cache()
Clear the cached metrics.
Call this method after retraining the model to ensure metrics are recomputed with the updated parameters.
- classmethod MultimodalAmortizedLDA.convert_legacy_save(dir_path, output_dir_path, overwrite=False, prefix=None, **save_kwargs)
Converts a legacy saved model (<v0.15.0) to the updated save format.
- Parameters:
dir_path (str) – Path to the directory where the legacy model is saved.
output_dir_path (str) – Path to save converted save files.
overwrite (bool (default: False)) – Overwrite existing data or not. If False and directory already exists at output_dir_path, an error will be raised.
prefix (str | None (default: None)) – Prefix of saved file names.
**save_kwargs – Keyword arguments passed into save().
- MultimodalAmortizedLDA.cross_modality_score(mod_a, mod_b, *, normalise=True, return_df=True)
Compute a SHARE-Topic-style cross-modal interaction matrix P_{a,b}.
- Returns:
P : shape (n_feat_a, n_feat_b) – Interaction score between every feature of mod_a and every feature of mod_b.
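The exact formula behind cross_modality_score is not given here. The sketch below makes two assumptions. First, the score links two features through the topics they share, P[i, j] = Σₖ wₖ ϕₖ,a[i] ϕₖ,b[j], with topic weights wₖ that are uniform by default (mean topic usage would be another natural choice). Second, normalise rescales P to sum to 1. cross_modality_sketch is a hypothetical helper.

```python
import numpy as np


def cross_modality_sketch(phi_a, phi_b, topic_weights=None, normalise=True):
    """Hypothetical SHARE-Topic-style score, not the topomics implementation.

    phi_a: (K, n_feat_a) and phi_b: (K, n_feat_b) topic-feature matrices (rows sum to 1).
    topic_weights: optional (K,) weights over topics; uniform if None (assumption).
    Returns P of shape (n_feat_a, n_feat_b), P[i, j] = sum_k w_k * phi_a[k, i] * phi_b[k, j].
    """
    phi_a = np.asarray(phi_a, dtype=float)
    phi_b = np.asarray(phi_b, dtype=float)
    K = phi_a.shape[0]
    w = np.full(K, 1.0 / K) if topic_weights is None else np.asarray(topic_weights, dtype=float)
    P = (phi_a * w[:, None]).T @ phi_b   # (n_feat_a, K) @ (K, n_feat_b)
    if normalise:
        P = P / P.sum()                  # assumption: rescale so all scores sum to 1
    return P


phi_rna = [[0.5, 0.5, 0.0], [0.0, 0.0, 1.0]]   # 2 topics x 3 genes
phi_prot = [[1.0, 0.0], [0.0, 1.0]]            # 2 topics x 2 proteins
P = cross_modality_sketch(phi_rna, phi_prot)
```

Here gene 2 is linked only to protein 1, because both load exclusively on topic 1.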
- MultimodalAmortizedLDA.data_registry(registry_key)
Returns the object in AnnData associated with the key in the data registry.
- MultimodalAmortizedLDA.deregister_manager(adata=None)
Deregisters the AnnDataManager instance associated with adata.
If adata is None, deregisters all AnnDataManager instances in both the class and instance-specific manager stores, except for the one associated with this model instance.
- MultimodalAmortizedLDA.differential_abundance(*args, **kwargs)
Not implemented for this model class.
Available in models that inherit from VAEMixin.
- Raises:
NotImplementedError
- MultimodalAmortizedLDA.fit(data)
Fit the model to the provided data.
- Parameters:
data – The input data to fit the model to.
- classmethod MultimodalAmortizedLDA.from_data(data, modalities=None, layers=None, spatial_keys=None, table_key='table', categorical_covariate_keys=None, continuous_covariate_keys=None, **model_kwargs)
Convenience constructor: setup + instantiation in one call.
This is equivalent to calling setup_data() followed by the model constructor. Automatically detects the input type and handles all preprocessing.
- Parameters:
data – Input data; can be:
  - AnnData: single or concatenated modalities
  - MuData: multiple modalities
  - SpatialData: spatial omics data
  - dict[str, AnnData]: dictionary mapping modality names to AnnData
modalities (list[str] | None (default: None)) – Modalities to use. For MuData/SpatialData, selects a subset. For AnnData, provide a single modality name (default: "rna"). For dict, uses the dict keys if None.
layers (dict[str, str | None] | str | None (default: None)) – Layer specification:
  - dict: per-modality layers, e.g. {"rna": "counts", "protein": None}
  - str: same layer for all modalities
  - None: use .X for all
spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys in .obsp:
  - dict: per-modality spatial keys
  - str: same key for all modalities
  - None: no spatial data
table_key (str (default: 'table')) – For SpatialData only: which table to use.
**model_kwargs – Additional arguments passed to the model constructor:
  - n_topics: number of topics (required)
  - n_hidden: hidden units (default: 128)
  - weight_mode: "equal", "universal", or "cell" (default: "equal")
  - likelihood_weight_mode: "none", "inverse_features", or "sqrt_inverse_features" (default: "none")
  - likelihood_weight_ref: "mean", "median", or "max" (default: "mean")
  - gcn_n_layers: number of graph conv layers for spatial encoders (default: 1)
  - gcn_hidden_dims: list of hidden sizes per graph conv layer
  - likelihoods: list of likelihoods per modality ("multinomial", "gamma_poisson"/"nb", "bernoulli"; auto-inferred if not provided)
- Returns:
MultimodalAmortizedLDA – Initialized model ready for training.
Examples
>>> # From MuData with layer selection
>>> model = MultimodalAmortizedLDA.from_data(
...     mdata, modalities=["rna", "protein"], layers={"rna": "counts"}, n_topics=20
... )
>>> # From SpatialData
>>> model = MultimodalAmortizedLDA.from_data(
...     sdata, table_key="table", layers="counts", spatial_keys="spatial_connectivities", n_topics=20
... )
>>> # From dict of AnnData
>>> model = MultimodalAmortizedLDA.from_data(
...     {"rna": adata_rna, "protein": adata_protein}, layers={"rna": "counts"}, n_topics=20
... )
>>> # From single AnnData
>>> model = MultimodalAmortizedLDA.from_data(adata, modalities=["rna"], layers="counts", n_topics=20)
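The three accepted forms of layers (None, a string, or a per-modality dict) can be expanded into one value per modality; spatial_keys works analogously. The sketch below shows that expansion. The helper name is hypothetical. Treating modalities absent from a dict as None (use .X) is an assumption, chosen to agree with the {"rna": "counts", "protein": None} example above.

```python
def resolve_per_modality(spec, modalities):
    """Hypothetical helper: expand a layers/spatial_keys spec into {modality: value or None}.

    - None -> every modality gets None (use .X / no spatial graph)
    - str  -> the same value for every modality
    - dict -> per-modality values; modalities missing from the dict get None (assumption)
    """
    if spec is None:
        return {m: None for m in modalities}
    if isinstance(spec, str):
        return {m: spec for m in modalities}
    if isinstance(spec, dict):
        unknown = set(spec) - set(modalities)
        if unknown:
            raise KeyError(f"Spec refers to unknown modalities: {sorted(unknown)}")
        return {m: spec.get(m) for m in modalities}
    raise TypeError(f"Expected None, str or dict, got {type(spec).__name__}")


per_mod = resolve_per_modality({"rna": "counts"}, ["rna", "protein"])
```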
- classmethod MultimodalAmortizedLDA.from_mudata(mdata, modality_order=None, layer_dict=None, spatial_key=None, spatial_modality_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **model_kwargs)
High-level constructor for multimodal AmortizedLDA from MuData.
- Parameters:
mdata (MuData) – MuData object containing multiple modalities.
modality_order (list[str] | None (default: None)) – Order of modalities to use. If None, uses all modalities in mdata.mod.keys().
layer_dict (dict[str, str] | None (default: None)) – Dictionary mapping modality names to the layer to use for each modality.
spatial_key (str | None (default: None)) – Single obsp key applied to all modalities (if spatial_modality_keys is not provided).
spatial_modality_keys (dict[str, str] | None (default: None)) – Mapping of modality -> obsp key for modality-specific spatial graphs.
**model_kwargs – Additional arguments passed to the model constructor. Common arguments include:
  - n_topics: number of topics (default: 20)
  - n_hidden: hidden units in encoders (default: 128)
  - weight_mode: "equal", "universal", or "cell" (default: "equal")
  - likelihood_weight_mode: "none", "inverse_features", or "sqrt_inverse_features" (default: "none")
  - likelihood_weight_ref: "mean", "median", or "max" (default: "mean")
  - gcn_n_layers: number of graph conv layers for spatial encoders (default: 1)
  - gcn_hidden_dims: list of hidden sizes per graph conv layer
  - likelihoods: list of likelihoods per modality ("multinomial", "gamma_poisson"/"nb", "bernoulli"; auto-inferred if not provided)
- MultimodalAmortizedLDA.get_anndata_manager(adata, required=False)
Retrieves the AnnDataManager for a given AnnData object.
Requires that self.id has been set. Checks for an AnnDataManager specific to this model instance.
- MultimodalAmortizedLDA.get_cell_entropy(adata=None, indices=None, batch_size=None, n_samples=100)
Compute per-cell entropy of cell-topic distributions.
- Parameters:
adata (AnnData | None (default: None)) – AnnData object with data. If None, uses the training data.
indices (Sequence[int] | None (default: None)) – Indices of cells to compute entropy for. If None, uses all cells.
batch_size (int | None (default: None)) – Batch size for computation. If None, processes all cells at once.
n_samples (int (default: 100)) – Number of posterior samples for Monte Carlo estimation.
- Returns:
np.ndarray – Per-cell entropy values, shape (n_cells,), with H(θₙ) = -Σₖ θₙ,ₖ log(θₙ,ₖ) for each cell n.
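Given a cell-topic matrix (for example from get_cell_topic_dist), the per-cell entropy formula above can be computed as follows. This sketch works on a single point estimate of Θ. How the method aggregates its n_samples posterior draws is not specified here. cell_entropy is a hypothetical helper.

```python
import numpy as np


def cell_entropy(theta, eps=1e-12):
    """H(theta_n) = -sum_k theta_nk * log(theta_nk) for each row of an (n_cells, K) matrix."""
    theta = np.asarray(theta, dtype=float)
    theta = theta / theta.sum(axis=1, keepdims=True)   # make each row a distribution
    return -(theta * np.log(theta + eps)).sum(axis=1)  # eps avoids log(0)


theta = [[0.25, 0.25, 0.25, 0.25],   # cell spread evenly over 4 topics: maximal entropy log(4)
         [1.0, 0.0, 0.0, 0.0]]       # cell assigned to a single topic: entropy 0
H = cell_entropy(theta)
```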
- MultimodalAmortizedLDA.get_cell_topic_dist(adata=None, indices=None, batch_size=None, n_samples=5000, return_dataframe=False)
Get the cell-topic matrix Θ (C × K).
- Parameters:
adata (AnnData | None (default: None)) – AnnData object to use (default: self.adata).
indices (Sequence[int] | None (default: None)) – Subset of cells to use.
batch_size (int | None (default: None)) – Batch size for inference.
n_samples (int (default: 5000)) – Number of samples for Monte Carlo estimation.
return_dataframe (bool (default: False)) – If True, return a pd.DataFrame (cells × topics) indexed by adata.obs_names with topic_k columns instead of a raw array.
- Returns:
Θ : np.ndarray | pd.DataFrame – Cell-topic matrix of shape (C, K), where C is the number of cells and K is the number of topics.
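The n_samples parameter controls a Monte Carlo estimate. The sketch below assumes a logistic-normal posterior in the spirit of scvi.model.AmortizedLDA: each cell's posterior is a diagonal Gaussian over unnormalized topic logits, and Θ is the average softmax of posterior samples. The mu and sigma inputs are hypothetical, since the method reads them from the trained encoder.

```python
import numpy as np


def mc_topic_expectation(mu, sigma, n_samples=5000, seed=0):
    """Monte Carlo estimate of E[softmax(z)] for z ~ N(mu, diag(sigma^2)).

    mu, sigma: (n_cells, K) logistic-normal posterior parameters (assumed inputs).
    Returns an (n_cells, K) matrix whose rows sum to 1.
    """
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    z = rng.normal(mu, sigma, size=(n_samples,) + mu.shape)  # (S, n_cells, K) samples
    z = z - z.max(axis=-1, keepdims=True)                    # numerically stable softmax
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    return p.mean(axis=0)                                     # average over samples


# With sigma = 0 every sample equals mu, so the estimate is just softmax(mu).
theta_det = mc_topic_expectation([[0.0, 0.0], [np.log(3.0), 0.0]], np.zeros((2, 2)), n_samples=10)
theta_mc = mc_topic_expectation(np.zeros((3, 5)), np.ones((3, 5)), n_samples=200)
```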
- MultimodalAmortizedLDA.get_elbo(adata=None, indices=None, batch_size=None)
Average ELBO across batches (higher is better).
Note: Pyro's Trace_ELBO.loss() returns -ELBO; this method negates it so the returned value is the actual ELBO (higher = better fit).
- MultimodalAmortizedLDA.get_entropy(adata=None, indices=None, batch_size=None, normalised=True)
Compute the mean entropy of cell-topic distributions.
Higher entropy means that each cell's topic proportions are spread more evenly across topics.
- Parameters:
adata (AnnData | None (default: None)) – AnnData object to use (default: self.adata).
indices (Sequence[int] | None (default: None)) – Subset of cells to use.
batch_size (int | None (default: None)) – Batch size for inference.
normalised (bool (default: True)) – Whether to normalize cell-topic distributions before computing entropy.
- Returns:
float – Mean entropy across cells.
- MultimodalAmortizedLDA.get_entropy_weight()
Get the entropy regularization weight.
- Returns:
float – Current entropy_weight value used for regularization.
- MultimodalAmortizedLDA.get_feature_topic_dist(modality=None, n_samples=5000, as_dict=False)
Monte Carlo estimate of E[ϕₖ,ₘ].
- Parameters:
modality (str | int | None (default: None)) – Modality name or index. If provided, return only that modality's topic-feature distribution; otherwise return all modalities.
n_samples (int (default: 5000)) – Number of MC samples from the variational posterior.
as_dict (bool (default: False)) – If True, return {m: DataFrame} per modality; otherwise concatenate along features (like the original single-modality API).
- Returns:
A single DataFrame concatenated along features (default, as_dict=False), or a dict of DataFrames keyed by modality if as_dict=True. In both cases, index = feature names and columns = topics.
- MultimodalAmortizedLDA.get_from_registry(adata, registry_key)
Returns the object in AnnData associated with the key in the data registry.
AnnData object should be registered with the model prior to calling this function via the self._validate_anndata method.
- MultimodalAmortizedLDA.get_last_entropy()
Get the mean entropy from the last forward pass through the model.
- MultimodalAmortizedLDA.get_last_topic_variance()
Get the last computed mean topic variance from training.
- MultimodalAmortizedLDA.get_latent_representation(adata=None, indices=None, batch_size=None, n_samples=5000, return_dataframe=False)
Infer θₙ for all cells (or a subset).
Alias for get_cell_topic_dist(). Returns the cell-topic matrix Θ (C × K) of softmax-normalized expectations as a np.ndarray by default, or a pd.DataFrame if return_dataframe=True.
- MultimodalAmortizedLDA.get_learned_dispersion(modality=None, n_samples=1000)
Get the learned or fixed dispersion parameters.
- Returns:
dict[str, np.ndarray] | np.ndarray – If modality is None, a dict mapping modality names to dispersion arrays; if modality is specified, the dispersion array for that modality.
Notes
- If learnable_dispersion=False, returns the fixed dispersion value.
- If learnable_dispersion=True and global_dispersion=True, returns an array of shape (1,).
- If learnable_dispersion=True and global_dispersion=False, returns an array of shape (n_features,).
- MultimodalAmortizedLDA.get_likelihood_per_modality(adata=None, indices=None, batch_size=None)
Compute the log-likelihood for each modality separately.
Higher is better.
- Returns:
dict[str, float] – Dictionary mapping modality names to log-likelihood values.
- MultimodalAmortizedLDA.get_likelihood_weights(return_format='dataframe')
Return per-modality likelihood scaling weights used in the generative model.
- MultimodalAmortizedLDA.get_modality_weights(adata=None, indices=None, batch_size=None, return_format='dataframe')
Get normalized mixing weights showing how much each modality contributes to topic assignments.
The mixing weights control how much each modality's encoder output contributes to the final cell-topic distribution (θ) in the Mixture-of-Experts architecture.
Returns weights in the range [0, 1] that sum to 1 per cell (or globally in "universal" mode). A higher weight means the model relies more on that modality when inferring topics.
With weight_mode "cell", the weights form a per-cell distribution; with the other modes, they are fixed values shared by all cells.
- Parameters:
adata (AnnData | None (default: None)) – AnnData object to compute weights for. If None, uses the registered dataset.
indices (Sequence[int] | None (default: None)) – Indices of cells to include. If None, uses all cells.
batch_size (int | None (default: None)) – Batch size for inference.
return_format (str (default: 'dataframe')) – Output format: "dataframe" returns a pd.DataFrame (cells × modalities); "dict" returns a dict mapping modality names to 1D arrays.
- Returns:
pd.DataFrame or dict[str, np.ndarray] – Normalized mixing weights for each cell and modality:
  - "universal" mode: a single row with the global weights
  - "equal" mode: uniform weights (1/n_modalities)
  - "cell" mode: per-cell learned weights
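The "dataframe" output can be summarized with ordinary pandas operations. The weights below are made-up values in the documented shape for "cell" mode (one row per cell, one column per modality, rows summing to 1). In practice the frame would come from model.get_modality_weights(). For "equal" mode, every entry would be 1/n_modalities.

```python
import numpy as np
import pandas as pd

# Made-up weights in the documented "dataframe" format (weight_mode="cell").
weights = pd.DataFrame(
    {"rna": [0.8, 0.3, 0.55], "protein": [0.2, 0.7, 0.45]},
    index=["cell_a", "cell_b", "cell_c"],
)

row_sums_ok = np.allclose(weights.sum(axis=1), 1.0)  # per-cell weights are normalized
dominant = weights.idxmax(axis=1)                    # modality each cell relies on most
mean_weight = weights.mean(axis=0)                   # average reliance on each modality
```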
- MultimodalAmortizedLDA.get_normalized_expression(*args, **kwargs)
Not implemented for this model class.
Available in RNA models that inherit from RNASeqMixin.
- Raises:
NotImplementedError
- MultimodalAmortizedLDA.get_perplexity(adata=None, indices=None, batch_size=None)
exp(-log_likelihood / total_counts); lower is better.
Uses the observation log-likelihood (via poutine.trace) rather than the full ELBO, so KL terms and regularization bonuses do not inflate the result. Note that the value explodes easily: the average negative log-likelihood per count is exponentiated, so even a moderately poor fit yields a very large perplexity.
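The formula above is easy to check on a reference case. If a model spreads probability uniformly over V features, every count contributes log(1/V), so the perplexity is exactly V. The same case also shows why the value explodes: when the log-likelihood per count is very negative, exp overflows. Comparing on the log scale avoids that. Both helpers are hypothetical and operate on an already-computed log-likelihood.

```python
import math


def perplexity(log_likelihood, total_counts):
    """exp(-log_likelihood / total_counts); lower is better."""
    return math.exp(-log_likelihood / total_counts)


def log_perplexity(log_likelihood, total_counts):
    """The same quantity on the log scale, which does not overflow for finite inputs."""
    return -log_likelihood / total_counts


# Uniform model over V = 50 features: each of the 1000 counts contributes log(1/50).
n_counts, V = 1000, 50
ll_uniform = n_counts * math.log(1.0 / V)
ppl_uniform = perplexity(ll_uniform, n_counts)   # equals V
```

A very poor fit with log-likelihood -1e6 over 1000 counts has a log-perplexity of 1000. Exponentiating that raises OverflowError in math.exp.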
- MultimodalAmortizedLDA.get_perplexity_per_modality(adata=None, indices=None, batch_size=None)
Compute perplexity for each modality separately.
Lower is better. Perplexity = exp(-log_likelihood / N_tokens).
- Returns:
dict[str, float] – Dictionary mapping modality names to perplexity values.
- MultimodalAmortizedLDA.get_setup_arg(setup_arg)
Returns the string provided to setup of a specific setup_arg.
- MultimodalAmortizedLDA.get_state_registry(registry_key)
Returns the state registry for the AnnDataField registered with this instance.
- MultimodalAmortizedLDA.get_top_features_per_topic(modality, n_features=10, return_scores=False)
Get the top N features for each topic in a specific modality.
- Returns:
dict[str, list[str]] or dict[str, list[tuple[str, float]]] – Dictionary mapping topic names (e.g., 'topic_0') to lists of top feature names or (feature_name, score) tuples.
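The documented return shape can be reproduced from a topic-feature matrix with a descending sort per topic. This sketch ranks features by their raw topic-feature weight, which is an assumption since the method may use a different score. It takes ϕ in (K, n_features) orientation, i.e. the transpose of the DataFrame returned by get_feature_topic_dist. top_features_sketch is a hypothetical helper.

```python
import numpy as np


def top_features_sketch(phi, feature_names, n_features=10, return_scores=False):
    """phi: (K, n_features) topic-feature matrix for one modality (hypothetical helper)."""
    phi = np.asarray(phi, dtype=float)
    result = {}
    for k, row in enumerate(phi):
        order = np.argsort(row)[::-1][:n_features]   # indices of the largest weights, descending
        if return_scores:
            result[f"topic_{k}"] = [(feature_names[i], float(row[i])) for i in order]
        else:
            result[f"topic_{k}"] = [feature_names[i] for i in order]
    return result


phi = [[0.1, 0.6, 0.3],
       [0.5, 0.2, 0.3]]
top = top_features_sketch(phi, ["GeneA", "GeneB", "GeneC"], n_features=2)
top_scored = top_features_sketch(phi, ["GeneA", "GeneB", "GeneC"], n_features=2, return_scores=True)
```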
- MultimodalAmortizedLDA.get_topic_diversity(modality=None)
Compute topic diversity (average pairwise cosine distance) per modality or overall.
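Average pairwise cosine distance between topics can be computed from a (K, n_features) topic-feature matrix as below. This covers a single modality. How the "overall" value combines modalities is not specified here. topic_diversity is a hypothetical helper.

```python
import numpy as np


def topic_diversity(phi):
    """Mean cosine distance over all pairs of distinct topics (rows of phi); needs K >= 2."""
    phi = np.asarray(phi, dtype=float)
    unit = phi / np.linalg.norm(phi, axis=1, keepdims=True)  # L2-normalize each topic
    cos_sim = unit @ unit.T
    iu = np.triu_indices(phi.shape[0], k=1)                  # each unordered pair once
    return float(np.mean(1.0 - cos_sim[iu]))


orthogonal = topic_diversity(np.eye(3))                # topics share no features: 1.0
duplicate = topic_diversity([[1.0, 2.0], [1.0, 2.0]])  # identical topics: 0.0
```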
- MultimodalAmortizedLDA.get_topic_variance(adata=None, indices=None, batch_size=None, n_samples=100)
Compute per-topic variance of topic usage across cells.
- Parameters:
adata (AnnData | None (default: None)) – AnnData object with data. If None, uses the training data.
indices (Sequence[int] | None (default: None)) – Indices of cells to compute variance for. If None, uses all cells.
batch_size (int | None (default: None)) – Batch size for computation. If None, processes all cells at once.
n_samples (int (default: 100)) – Number of posterior samples for Monte Carlo estimation.
- Returns:
np.ndarray – Per-topic variance values, shape (n_topics,), where Var(θ[:, k]) is the variance of topic k usage across all cells.
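On a point estimate of Θ with shape (n_cells, K), the quantity Var(θ[:, k]) is a column-wise variance. This sketch uses the population variance (ddof=0), which is an assumption. It shows that topics used identically by every cell have zero variance, while topics that separate cells do not.

```python
import numpy as np


def topic_variance(theta):
    """Var(theta[:, k]) across cells for each topic k (population variance, ddof=0)."""
    return np.asarray(theta, dtype=float).var(axis=0)


separating = topic_variance([[1.0, 0.0], [0.0, 1.0]])  # each topic marks one cell
collapsed = topic_variance([[0.5, 0.5], [0.5, 0.5]])   # every cell uses topics identically
```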
- MultimodalAmortizedLDA.get_topic_variance_weight()
Get the topic variance regularization weight.
- Returns:
float – Topic variance weight used during training.
- MultimodalAmortizedLDA.get_var_names(legacy_mudata_format=False)
Variable names of input data.
- classmethod MultimodalAmortizedLDA.load(dir_path, adata=None, accelerator='auto', device='auto', prefix=None, backup_url=None, datamodule=None, allowed_classes_names_list=None)
Instantiate a model from the saved output.
- Parameters:
dir_path (str) – Path to saved outputs.
adata (AnnData | MuData | None (default: None)) – AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the saved scvi setup dictionary. If None, will check for and load anndata saved with the model. If False, will load the model without AnnData.
accelerator (str (default: 'auto')) – Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances.
device (int | str (default: 'auto')) – The device to use. Can be set to a non-negative index (int or str) or "auto" for automatic selection based on the chosen accelerator. If set to "auto" and accelerator is not determined to be "cpu", then device will be set to the first available device.
prefix (str | None (default: None)) – Prefix of saved file names.
backup_url (str | None (default: None)) – URL to retrieve saved outputs from if not present on disk.
datamodule (LightningDataModule | None (default: None)) – EXPERIMENTAL A LightningDataModule instance to use for training in place of the default DataSplitter. Can only be passed in if the model was not initialized with AnnData.
allowed_classes_names_list (list[str] | None (default: None)) – List of allowed class names to be loaded (besides the original class name).
- Returns:
Model with loaded state dictionaries.
Examples
>>> model = MultimodalAmortizedLDA.load(save_path, adata)
>>> model.get_....
- static MultimodalAmortizedLDA.load_registry(dir_path, prefix=None)
Return the full registry saved with the model.
- MultimodalAmortizedLDA.predict(data)
Predict using the fitted model on the provided data.
- Parameters:
data – The input data for prediction.
- classmethod MultimodalAmortizedLDA.register_manager(adata_manager)
Registers an AnnDataManager instance with this model class.
Stores the AnnDataManager reference in a class-specific manager store. Intended for use in the setup_anndata() class method followed up by retrieval of the AnnDataManager via the _get_most_recent_anndata_manager() method in the model init method.
Notes
Subsequent calls to this method with an AnnDataManager instance referring to the same underlying AnnData object will overwrite the reference to the previous AnnDataManager.
- MultimodalAmortizedLDA.save(dir_path, prefix=None, overwrite=False, save_anndata=False, save_kwargs=None, legacy_mudata_format=False, datamodule=None, **anndata_write_kwargs)
Save the state of the model.
Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0.
- Parameters:
dir_path (str) – Path to a directory.
prefix (str | None (default: None)) – Prefix to prepend to saved file names.
overwrite (bool (default: False)) – Overwrite existing data or not. If False and directory already exists at dir_path, an error will be raised.
save_anndata (bool (default: False)) – If True, also saves the AnnData.
save_kwargs (dict | None (default: None)) – Keyword arguments passed into save().
legacy_mudata_format (bool (default: False)) – If True, saves the model var_names in the legacy format if the model was trained with a MuData object. The legacy format is a flat array with variable names across all modalities concatenated, while the new format is a dictionary with keys corresponding to the modality names and values corresponding to the variable names for each modality.
datamodule (LightningDataModule | None (default: None)) – EXPERIMENTAL A LightningDataModule instance to use for training in place of the default DataSplitter. Can only be passed in if the model was not initialized with AnnData.
**anndata_write_kwargs – Kwargs for write().
- classmethod MultimodalAmortizedLDA.setup_adata_dict(adata_dict, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
Setup method for dict[str, AnnData] input.
Converts the dictionary to a concatenated AnnData and processes it for model training.
- Parameters:
adata_dict (dict[str, AnnData]) – Dictionary mapping modality names to AnnData objects.
layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (per-modality dict, or string for all).
spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (per-modality dict, or string for all).
**kwargs – Additional arguments passed to scvi registration.
- Returns:
adata_concat – The processed and registered AnnData object.
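Concatenating modalities along the feature axis is what makes n_inputs_modalities necessary, because the per-modality feature counts are the only way to slice the combined matrix back apart. This numpy sketch shows the round trip on plain arrays. It assumes cells are already aligned in the same order across modalities. The real method operates on AnnData objects and their obs_names, and both helpers are hypothetical.

```python
import numpy as np


def concat_modalities(mod_dict):
    """Concatenate per-modality (n_cells, n_features_m) matrices along features.

    Returns the combined matrix, the modality order, and the per-modality feature
    counts (the kind of list passed as n_inputs_modalities).
    """
    names = list(mod_dict)
    mats = [np.asarray(mod_dict[m]) for m in names]
    if len({m.shape[0] for m in mats}) != 1:
        raise ValueError("All modalities must have the same number of cells")
    X = np.hstack(mats)
    feat_counts = [m.shape[1] for m in mats]
    return X, names, feat_counts


def split_modalities(X, names, feat_counts):
    """Inverse of concat_modalities: slice the feature axis back into modalities."""
    bounds = np.cumsum([0] + feat_counts)
    return {m: X[:, bounds[i]:bounds[i + 1]] for i, m in enumerate(names)}


rna = np.arange(12).reshape(4, 3)      # 4 cells x 3 genes
protein = np.arange(8).reshape(4, 2)   # 4 cells x 2 proteins
X, names, feat_counts = concat_modalities({"rna": rna, "protein": protein})
parts = split_modalities(X, names, feat_counts)
```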
- classmethod MultimodalAmortizedLDA.setup_anndata(adata, layer=None, spatial_key=None, modalities=None, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, encoder_extra_obsm_key=None, **kwargs)
Sets up the AnnData object for this model.
A mapping will be created between the data fields used by this model and their respective locations in adata. None of the data in adata are modified; fields are only added to adata.
- Parameters:
adata (AnnData) – AnnData object. Rows represent cells, columns represent features.
layer (str | None (default: None)) – If not None, uses this as the key in adata.layers for raw count data.
spatial_key (str | None (default: None)) – Optional key in adata.obsp pointing to a precomputed spatial graph.
modalities (list[str] | None (default: None)) – List of modality names (for the new API). If None, defaults to ["rna"].
layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (for the new API). Can be a string or a dict.
spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (for the new API). Can be a string or a dict.
categorical_covariate_keys (list[str] | None (default: None)) – Keys in adata.obs for categorical covariates (e.g., batch, sample). These will be embedded and used for batch effect correction.
continuous_covariate_keys (list[str] | None (default: None)) – Keys in adata.obs for continuous covariates (e.g., age, percent_mito). These will be directly concatenated to the covariate representation.
encoder_extra_obsm_key (str | None (default: None)) – Optional key in adata.obsm for extra encoder input features (e.g., precomputed SGC-smoothed features). These are concatenated to the encoder input but NOT used by the decoder/likelihood. The reconstruction target remains the raw counts in .X.
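One way to precompute SGC-smoothed features for encoder_extra_obsm_key is the propagation used by Simplified Graph Convolution, (D^-1/2 (A + I) D^-1/2)^k X. Whether topomics expects exactly this normalization is an assumption; its own sgc_n_layers and spatial_mode options are not described here. The sketch uses a dense adjacency for clarity, whereas a large spatial graph from adata.obsp would call for scipy.sparse. sgc_smooth is a hypothetical helper.

```python
import numpy as np


def sgc_smooth(X, A, n_layers=1):
    """SGC-style propagation: (D^-1/2 (A + I) D^-1/2)^n_layers @ X.

    X: (n_cells, n_features) features; A: (n_cells, n_cells) symmetric adjacency
    (e.g. a dense copy of a spatial graph stored in adata.obsp).
    """
    A = np.asarray(A, dtype=float)
    A_hat = A + np.eye(A.shape[0])                            # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    S = A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]     # symmetric normalization
    out = np.asarray(X, dtype=float)
    for _ in range(n_layers):
        out = S @ out                                         # one propagation step
    return out


X = np.array([[1.0], [3.0]])
unchanged = sgc_smooth(X, np.zeros((2, 2)))                   # no edges: S = I
averaged = sgc_smooth(X, np.array([[0.0, 1.0], [1.0, 0.0]]))  # two linked cells are averaged
```

The result could then be stored under an obsm key of your choosing (for example, a hypothetical "X_sgc") and passed as encoder_extra_obsm_key. The raw counts in .X stay the reconstruction target.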
- classmethod MultimodalAmortizedLDA.setup_data(data, modalities=None, layers=None, spatial_keys=None, table_key='table', categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
Universal setup method with automatic type detection.
Detects the type of input data (AnnData, MuData, SpatialData, or dict[str, AnnData]) and routes to the appropriate type-specific setup method.
This method follows scvi-tools conventions, performing data registration as a side effect. After calling this method, instantiate the model with the same data object.
- Parameters:
data – Input data of any supported type:
  - AnnData: single modality data
  - MuData: multi-modal data
  - SpatialData: spatial omics data
  - dict[str, AnnData]: dictionary mapping modality names to AnnData objects
modalities (list[str] | None (default: None)) – List of modality names to use. If None, uses all available modalities.
layers (dict[str, str | None] | str | None (default: None)) – Layer specifications for data extraction. Can be:
  - None: use .X for all modalities
  - str: use the same layer for all modalities (e.g., "counts")
  - dict: per-modality layer specification (e.g., {"rna": "counts", "protein": "raw"})
spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys in .obsp. Can be:
  - None: no spatial graphs
  - str: use the same key for all modalities
  - dict: per-modality spatial keys
table_key (str (default: 'table')) – For SpatialData only: key in sdata.tables to extract.
**kwargs – Additional arguments passed to scvi registration.
Examples
>>> # With MuData
>>> MultimodalAmortizedLDA.setup_data(
...     mdata, modalities=["rna", "protein"], layers="counts", spatial_keys="connectivities"
... )
>>> model = MultimodalAmortizedLDA(mdata, n_topics=20)
>>> # With AnnData
>>> MultimodalAmortizedLDA.setup_data(adata, modalities=["rna"], layers="counts")
>>> model = MultimodalAmortizedLDA(adata, n_topics=20)
>>> # With dict[str, AnnData]
>>> adata_dict = {"rna": adata_rna, "protein": adata_protein}
>>> MultimodalAmortizedLDA.setup_data(adata_dict, layers={"rna": "counts", "protein": "raw"})
>>> # For dict, get the processed AnnData from the return value
>>> adata_concat = MultimodalAmortizedLDA.setup_data(adata_dict, layers="counts")
>>> model = MultimodalAmortizedLDA(adata_concat, n_topics=20)
- classmethod MultimodalAmortizedLDA.setup_mudata(mdata, modality_order=None, layer_dict=None, spatial_key=None, spatial_modality_keys=None, modalities=None, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)#
Setup MuData for multimodal AmortizedLDA.
This method stores modality metadata in
mdata.unsand prepares the data for the model without concatenating features.- Parameters:
mdata (MuData) – MuData object containing multiple modalities.
modality_order (list[str] | None (default: None)) – Order of modalities to use. If None, uses all modalities in mdata.mod.keys(). (Old parameter name; prefer modalities.)
layer_dict (dict[str, str] | None (default: None)) – Dictionary mapping modality names to the layer to use for each modality. (Old parameter name; prefer layers.)
spatial_key (str | None (default: None)) – Single obsp key applied to all modalities when spatial_modality_keys is not provided. (Old parameter name; prefer spatial_keys.)
spatial_modality_keys (dict[str, str] | None (default: None)) – Mapping of modality -> obsp key for modality-specific spatial graphs. (Old parameter name; prefer spatial_keys.)
modalities (list[str] | None (default: None)) – List of modality names to use (new parameter name; alias for modality_order).
layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (new parameter name). Can be:
  - None: use .X for all modalities
  - str: use the same layer for all modalities
  - dict: per-modality layer specification
spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (new parameter name). Can be:
  - None: no spatial graphs
  - str: use the same key for all modalities
  - dict: per-modality spatial keys
**kwargs – Additional arguments passed to setup_anndata.
- Return type:
- Returns:
- mdata
The input MuData object with metadata stored in .uns.
- modality_names
List of modality names in the order they will be processed.
- feat_counts
List of feature counts per modality.
Notes
Layer and spatial-graph lookup is delegated to shared extraction utilities, which is what allows layers and spatial_keys to accept None, a single string, or a per-modality dict.
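setup_mudata accepts both the old names (modality_order, layer_dict, spatial_key, spatial_modality_keys) and the new ones (modalities, layers, spatial_keys). A minimal sketch of how the aliases could be reconciled, and how modality_names and feat_counts follow from them, is below. resolve_mudata_args is hypothetical; mod_n_vars stands in for {name: mdata.mod[name].n_vars}, and letting a new name win over its old counterpart is an assumption, since the documentation only says to prefer the new names.

```python
from typing import Optional, Union

Spec = Union[None, str, dict]


def resolve_mudata_args(
    mod_n_vars: dict,
    modality_order: Optional[list] = None,
    layer_dict: Optional[dict] = None,
    spatial_key: Optional[str] = None,
    spatial_modality_keys: Optional[dict] = None,
    modalities: Optional[list] = None,
    layers: Spec = None,
    spatial_keys: Spec = None,
):
    """Hypothetical reconciliation of old and new setup_mudata arguments."""
    # Assumption: a new-style name wins when both the new and the old one are given.
    names = modalities if modalities is not None else modality_order
    if names is None:
        names = list(mod_n_vars)  # default: every modality, in mdata.mod order
    layer_spec = layers if layers is not None else layer_dict
    if spatial_keys is None:
        # Documented old-style rule: spatial_key applies only when
        # spatial_modality_keys is not provided.
        spatial_keys = spatial_modality_keys if spatial_modality_keys is not None else spatial_key
    feat_counts = [mod_n_vars[m] for m in names]  # one feature count per modality
    return list(names), layer_spec, spatial_keys, feat_counts


names, layer_spec, graph_spec, feat_counts = resolve_mudata_args(
    {"rna": 2000, "protein": 30}, layer_dict={"rna": "counts"}, spatial_key="connectivities"
)
print(names, feat_counts)  # ['rna', 'protein'] [2000, 30]
```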
- classmethod MultimodalAmortizedLDA.setup_spatialdata(sdata, table_key='table', modalities=None, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)
Setup method for SpatialData input.
Extracts the specified table from SpatialData and processes it for model training.
- Parameters:
sdata – SpatialData object containing spatial omics data.
table_key (str (default: 'table')) – Key in sdata.tables to extract. Default: "table".
modalities (list[str] | None (default: None)) – List of modality names to use. If None, uses all available modalities.
layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (per-modality dict, or a string for all modalities).
spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (per-modality dict, or a string for all modalities).
**kwargs – Additional arguments passed to scvi registration.
- Returns:
adata_concat – The processed and registered AnnData object.
- MultimodalAmortizedLDA.to_device(device)
Move the model to the device.
- Parameters:
device (str | int | torch.device) – Device to move the model to. Options: 'cpu' for CPU; an integer GPU index (e.g., 0); 'cuda:X', where X is the GPU index (e.g., 'cuda:0'); or a torch.device object (including XLA devices for TPU). See torch.device for more info.
Examples
>>> adata = scvi.data.synthetic_iid()
>>> model = scvi.model.SCVI(adata)
>>> model.to_device("cpu")  # moves model to CPU
>>> model.to_device("cuda:0")  # moves model to GPU 0
>>> model.to_device(0)  # also moves model to GPU 0
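The examples treat an integer index and the matching 'cuda:X' string as the same device. The small pure-Python sketch below illustrates only that convention; normalize_device is a hypothetical helper and is not scvi's implementation of to_device.

```python
def normalize_device(device):
    """Hypothetical helper: map an integer GPU index to its 'cuda:X' string.

    Strings such as 'cpu' or 'cuda:1', and any other object (e.g. a
    torch.device), are returned unchanged.
    """
    if isinstance(device, bool):  # bool subclasses int; refuse True/False as an index
        raise TypeError("device must be a str, int, or torch.device, not bool")
    if isinstance(device, int):
        if device < 0:
            raise ValueError(f"GPU index must be non-negative, got {device}")
        return f"cuda:{device}"
    return device


# An integer index and the explicit string name the same GPU.
assert normalize_device(0) == normalize_device("cuda:0") == "cuda:0"
```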
- MultimodalAmortizedLDA.train(*args, validation_size=None, **kwargs)
Overrides train so that validation runs every epoch whenever a validation split is requested.
scvi's Trainer defaults to check_val_every_n_epoch = sys.maxsize unless early stopping or checkpointing is enabled, which effectively disables the validation loop. When a validation set is present, this override sets it to 1 so that elbo_val is logged every epoch.
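The defaulting rule can be sketched as a pure function over the trainer keyword arguments. resolve_trainer_kwargs is a hypothetical name, and treating only a positive validation_size as a requested split is an assumption; the actual override may decide differently. An explicitly passed check_val_every_n_epoch is left untouched.

```python
from typing import Any, Optional


def resolve_trainer_kwargs(validation_size: Optional[float], **trainer_kwargs: Any) -> dict:
    """Hypothetical sketch: validate every epoch whenever a validation split is requested."""
    kwargs = dict(trainer_kwargs)
    # Assumption: a split counts as "requested" when validation_size is positive.
    split_requested = validation_size is not None and validation_size > 0
    # Only fill in the default; an explicit user value always wins. Without a split,
    # the key stays absent and the Trainer's own default applies.
    if split_requested:
        kwargs.setdefault("check_val_every_n_epoch", 1)
    return kwargs


print(resolve_trainer_kwargs(0.1, max_epochs=50))
# {'max_epochs': 50, 'check_val_every_n_epoch': 1}
```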
- MultimodalAmortizedLDA.transfer_fields(adata, **kwargs)
Transfer fields from a model to an AnnData object.
- Return type:
- MultimodalAmortizedLDA.update_setup_method_args(setup_method_args)
Update setup method args.
- Parameters:
setup_method_args (dict) – Despite its name, this is a dict of setup-method kwargs; its entries are used to update the existing values in this instance's registry.
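The description suggests dict.update-style semantics on the stored setup arguments: existing keys are overwritten, new keys are added, and the rest of the registry is left alone. The sketch below models the registry as a plain dict; the "setup_args" key, the "model_name" entry, and the helper name update_setup_args are assumptions for illustration, not the documented structure of this class.

```python
def update_setup_args(registry: dict, setup_method_args: dict) -> None:
    """Hypothetical sketch: merge new setup kwargs into a registry dict, in place."""
    # Assumption: setup kwargs live under a "setup_args" key of the registry.
    registry.setdefault("setup_args", {}).update(setup_method_args)


registry = {
    "setup_args": {"layer": "counts", "batch_key": None},
    "model_name": "MultimodalAmortizedLDA",  # unrelated entry, left untouched
}
update_setup_args(registry, {"batch_key": "sample"})
print(registry["setup_args"])  # {'layer': 'counts', 'batch_key': 'sample'}
```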
- MultimodalAmortizedLDA.view_anndata_setup(adata=None, hide_state_registries=False)
Print summary of the setup for the initial AnnData or a given AnnData object.
- Parameters:
adata (AnnData | MuData | None (default: None)) – AnnData object set up with setup_anndata or transfer_fields().
hide_state_registries (bool (default: False)) – If True, prints a shortened summary without details of each state registry.
- Return type:
- MultimodalAmortizedLDA.view_registry(hide_state_registries=False)
Print a summary of the registry.
- static MultimodalAmortizedLDA.view_setup_args(dir_path, prefix=None)
Print the arguments used to set up a saved model.