topomics.models.MultimodalAmortizedLDA

topomics.models.MultimodalAmortizedLDA#

class topomics.models.MultimodalAmortizedLDA(adata, n_inputs_modalities, likelihoods, n_topics=20, n_hidden=128, cell_topic_prior=None, topic_feature_prior=None, topic_feature_prior_type='logistic_normal', use_feature_background=True, dispersion_rna=1.0, learnable_dispersion=False, global_dispersion=True, modality_names=None, weight_mode='cell', likelihood_weight_mode='none', likelihood_weight_ref='mean', gcn_n_layers=1, gcn_n_pre_layers=0, gcn_conv_type='GATv2Conv', gcn_hidden_dims=None, gcn_alpha_init=0.7, gcn_use_learned_alpha=True, normalize_encoder_inputs=True, encoder_scale_factor=1000000.0, entropy_weight=0.01, topic_variance_weight=1.0, kl_weight=1.0, encode_covariates=True, bg_offset=1e-15, learnable_bg=True, aggregation_type='moe', att_dim=32, spatial_mode='gcn', sgc_n_layers=1, gcn_sampling='approx', gcn_fan_out=None, gcn_conv_first=False)#

Multimodal Amortized LDA with Mixture-of-Experts (MoE)

Extends scvi.model.AmortizedLDA to M modalities with modality-specific encoders and likelihoods. Each modality is encoded separately, and representations are mixed via weighted Gaussian combination before inferring the shared cell-topic distribution θₙ.

Parameters:
  • adata (AnnData) – AnnData with concatenated features (RNA + protein + …).

  • n_inputs_modalities (list[int]) – List with feature counts per modality, in the order they appear in adata.X.

  • likelihoods (list[str]) – Length-matched list of likelihood strings for each modality.

  • n_topics (int (default: 20)) – Number of topics (K).

  • n_hidden (int (default: 128)) – Hidden units of each encoder network.

  • cell_topic_prior (float | Sequence[float] | None (default: None)) – Dirichlet concentration for θₙ. None ⇒ symmetric 1/K.

  • topic_feature_prior (float | Sequence[float] | None (default: None)) – Dirichlet concentration for each ϕₖ,ₘ. None ⇒ symmetric 1/K.

  • weight_mode (str (default: 'cell')) – How to weight modality-specific representations:
      - "equal": All modalities weighted equally
      - "universal": Learn a single weight per modality
      - "cell": Learn per-cell, per-modality weights (default)
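To make the role of n_inputs_modalities concrete, the sketch below splits one concatenated feature row into per-modality blocks using the per-modality feature counts. The helper split_by_modality and the toy counts are hypothetical and are not part of topomics; they only illustrate the "in the order they appear in adata.X" convention.

```python
from itertools import accumulate


def split_by_modality(row, n_inputs_modalities):
    """Split one concatenated feature row into per-modality blocks.

    Hypothetical helper for illustration; not part of the library.
    """
    if sum(n_inputs_modalities) != len(row):
        raise ValueError("feature counts must sum to the row length")
    # Cumulative offsets mark where each modality block ends.
    ends = list(accumulate(n_inputs_modalities))
    starts = [0] + ends[:-1]
    return [row[s:e] for s, e in zip(starts, ends)]


# Toy cell: 3 RNA features followed by 2 protein features.
row = [5, 0, 2, 10, 7]
blocks = split_by_modality(row, [3, 2])
print(blocks)  # [[5, 0, 2], [10, 7]]
```

The counts must sum to the number of columns in the concatenated matrix, which is why the sketch raises an error otherwise.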

Notes

The Mixture-of-Experts architecture processes each modality through a separate encoder network, then combines their latent representations using learned or fixed weights.
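A minimal sketch of the combination step, assuming the mixing is a convex combination of per-modality latent means followed by a softmax onto the topic simplex. The exact rule used by the model is not stated on this page, so the function names and the combination rule below are assumptions made for illustration.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mix_modalities(means, weights):
    """Convex combination of per-modality latent means (assumed rule)."""
    total = sum(weights)
    norm = [w / total for w in weights]  # weights are normalised to sum to 1
    n_topics = len(means[0])
    return [sum(w * mu[k] for w, mu in zip(norm, means)) for k in range(n_topics)]


# Two modalities, three topics (toy numbers).
mu_rna = [2.0, 0.0, -1.0]
mu_protein = [0.0, 2.0, -1.0]

# Equal weights correspond to the "equal" weighting described above.
mixed = mix_modalities([mu_rna, mu_protein], weights=[1.0, 1.0])
theta = softmax(mixed)
print(mixed)  # [1.0, 1.0, -1.0]
```

With "universal" or "cell" weighting, the only change in this sketch would be where the weights come from: one learned vector for all cells, or one vector per cell.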

Attributes table#

adata

Data attached to model instance.

adata_manager

Manager instance associated with self.adata.

device

The current device that the module's params are on.

get_normalized_function_name

Name of the method that returns normalized values for this model.

history

Returns computed metrics during training.

is_trained

Whether the model has been trained.

registry

Registry attached to the model instance.

run_id

Returns the run id of the model.

run_name

Returns the run name of the model.

spatial

summary_string

Summary string of the model.

test_indices

Observations that are in test set.

train_indices

Observations that are in train set.

validation_indices

Observations that are in validation set.

Methods table#

check_input(mdata, modalities)

Validate and process the input data.

check_modalities_names()

Standardize and validate modality keys in data_dict.

clear_metric_cache()

Clear the cached metrics.

convert_legacy_save(dir_path, output_dir_path)

Converts a legacy saved model (<v0.15.0) to the updated save format.

cross_modality_score(mod_a, mod_b, *[, ...])

Compute SHARE-Topic–style cross-modal interaction matrix P_{a,b}

data_registry(registry_key)

Returns the object in AnnData associated with the key in the data registry.

deregister_manager([adata])

Deregisters the AnnDataManager instance associated with adata.

differential_abundance(*args, **kwargs)

Not implemented for this model class.

fit(data)

Fit the model to the provided data.

from_data(data[, modalities, layers, ...])

Convenience constructor: setup + instantiation in one call.

from_mudata(mdata[, modality_order, ...])

High-level constructor for multimodal AmortizedLDA from MuData.

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object.

get_cell_entropy([adata, indices, ...])

Compute per-cell entropy of cell-topic distributions.

get_cell_topic_dist([adata, indices, ...])

Get the cell-topic matrix Θ (C × K).

get_elbo([adata, indices, batch_size])

Average ELBO across batches (higher is better).

get_entropy([adata, indices, batch_size, ...])

Compute mean entropy of cell-topic distributions.

get_entropy_weight()

Get the entropy regularization weight.

get_feature_topic_dist([modality, ...])

Monte-Carlo estimate of E[ϕₖ,ₘ].

get_from_registry(adata, registry_key)

Returns the object in AnnData associated with the key in the data registry.

get_last_entropy()

Get the mean entropy from the last forward pass through the model.

get_last_topic_variance()

Get the last computed mean topic variance from training.

get_latent_representation([adata, indices, ...])

Infer θₙ for all cells (or subset).

get_learned_dispersion([modality, n_samples])

Get the learned or fixed dispersion parameters.

get_likelihood_per_modality([adata, ...])

Compute log-likelihood for each modality separately.

get_likelihood_weights([return_format])

Return per-modality likelihood scaling weights used in the generative model.

get_modality_weights([adata, indices, ...])

Get normalized mixing weights showing how much each modality contributes to topic assignments.

get_normalized_expression(*args, **kwargs)

Not implemented for this model class.

get_perplexity([adata, indices, batch_size])

exp( -log_likelihood / total_counts ) — lower is better.

get_perplexity_per_modality([adata, ...])

Compute perplexity for each modality separately.

get_setup_arg(setup_arg)

Returns the string provided to setup of a specific setup_arg.

get_state_registry(registry_key)

Returns the state registry for the AnnDataField registered with this instance.

get_top_features_per_topic(modality[, ...])

Get top N features for each topic in a specific modality.

get_topic_diversity([modality])

Compute topic diversity (average pairwise cosine distance) per modality or overall.

get_topic_variance([adata, indices, ...])

Compute per-topic variance of topic usage across cells.

get_topic_variance_weight()

Get the topic variance regularization weight.

get_var_names([legacy_mudata_format])

Variable names of input data.

load(dir_path[, adata, accelerator, device, ...])

Instantiate a model from the saved output.

load_registry(dir_path[, prefix])

Return the full registry saved with the model.

predict(data)

Predict using the fitted model on the provided data.

register_manager(adata_manager)

Registers an AnnDataManager instance with this model class.

save(dir_path[, prefix, overwrite, ...])

Save the state of the model.

setup_adata_dict(adata_dict[, layers, ...])

Setup method for dict[str, AnnData] input.

setup_anndata(adata[, layer, spatial_key, ...])

Sets up the AnnData object for this model.

setup_data(data[, modalities, layers, ...])

Universal setup method with automatic type detection.

setup_mudata(mdata[, modality_order, ...])

Setup MuData for multimodal AmortizedLDA.

setup_spatialdata(sdata[, table_key, ...])

Setup method for SpatialData input.

to_device(device)

Move the model to the device.

train(*args[, validation_size])

Override to default to running validation when a split is requested.

transfer_fields(adata, **kwargs)

Transfer fields from a model to an AnnData object.

update_setup_method_args(setup_method_args)

Update setup method args.

view_anndata_setup([adata, ...])

Print summary of the setup for the initial AnnData or a given AnnData object.

view_registry([hide_state_registries])

Prints summary of the registry.

view_setup_args(dir_path[, prefix])

Print args used to setup a saved model.

view_setup_method_args()

Prints setup kwargs used to produce a given registry.

Attributes#

MultimodalAmortizedLDA.adata#

Data attached to model instance.

MultimodalAmortizedLDA.adata_manager#

Manager instance associated with self.adata.

MultimodalAmortizedLDA.device#

The current device that the module’s params are on.

MultimodalAmortizedLDA.get_normalized_function_name#

Name of the method that returns normalized values for this model.

MultimodalAmortizedLDA.history#

Returns computed metrics during training.

MultimodalAmortizedLDA.is_trained#

Whether the model has been trained.

MultimodalAmortizedLDA.registry#

Registry attached to the model instance.

MultimodalAmortizedLDA.run_id#

Returns the run id of the model. Used in MLflow.

MultimodalAmortizedLDA.run_name#

Returns the run name of the model. Used in MLflow.

MultimodalAmortizedLDA.spatial: bool = False#

MultimodalAmortizedLDA.summary_string#

Summary string of the model.

MultimodalAmortizedLDA.test_indices#

Observations that are in test set.

MultimodalAmortizedLDA.train_indices#

Observations that are in train set.

MultimodalAmortizedLDA.validation_indices#

Observations that are in validation set.

Methods#

MultimodalAmortizedLDA.check_input(mdata, modalities)#

Validate and process the input data.

Checks that the data are AnnData or MuData objects and that the modalities are correctly specified.

MultimodalAmortizedLDA.check_modalities_names()#

Standardize and validate modality keys in data_dict.

Maps various synonyms to ‘rna’, ‘protein’, or ‘chromatin’, and rebuilds data_dict with standardized keys.

MultimodalAmortizedLDA.clear_metric_cache()#

Clear the cached metrics.

Call this method after retraining the model to ensure metrics are recomputed with the updated parameters.

classmethod MultimodalAmortizedLDA.convert_legacy_save(dir_path, output_dir_path, overwrite=False, prefix=None, **save_kwargs)#

Converts a legacy saved model (<v0.15.0) to the updated save format.

Parameters:
  • dir_path (str) – Path to the directory where the legacy model is saved.

  • output_dir_path (str) – Path to save converted save files.

  • overwrite (bool (default: False)) – Overwrite existing data or not. If False and directory already exists at output_dir_path, an error will be raised.

  • prefix (str | None (default: None)) – Prefix of saved file names.

  • **save_kwargs – Keyword arguments passed into save().

Return type:

None

MultimodalAmortizedLDA.cross_modality_score(mod_a, mod_b, *, normalise=True, return_df=True)#

Compute SHARE-Topic–style cross-modal interaction matrix P_{a,b}

Parameters:
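  • mod_a (str) – Name of the first modality; its features index the rows of P.

  • mod_b (str) – Name of the second modality; its features index the columns of P.

  • normalise (bool (default: True)) – Whether to normalise the interaction matrix.

  • return_df (bool (default: True)) – If True, return a DataFrame; otherwise return an ndarray.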

Return type:

ndarray | DataFrame

Returns:

P : shape (n_feat_a, n_feat_b) – interaction score between every feature of mod_a and every feature of mod_b
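The sketch below shows one way such an interaction matrix could be formed, assuming the score for a feature pair is the product of the two topic-feature probabilities averaged over topics with uniform topic weights. This formula is an assumption made for illustration and is not confirmed as the method's actual computation; toy_cross_modality_score is a hypothetical name.

```python
def toy_cross_modality_score(phi_a, phi_b, normalise=True):
    """Topic-averaged product of two topic-feature matrices (assumed formula).

    phi_a: K x n_feat_a and phi_b: K x n_feat_b, given as lists of per-topic rows.
    Returns a matrix of shape (n_feat_a, n_feat_b) as nested lists.
    """
    n_topics = len(phi_a)
    n_a, n_b = len(phi_a[0]), len(phi_b[0])
    P = [
        [sum(phi_a[k][i] * phi_b[k][j] for k in range(n_topics)) / n_topics
         for j in range(n_b)]
        for i in range(n_a)
    ]
    if normalise:
        # Rescale so that all entries sum to 1.
        total = sum(sum(r) for r in P)
        P = [[v / total for v in r] for r in P]
    return P


# Two topics; modality a has 2 features, modality b has 3 (toy numbers).
phi_a = [[0.9, 0.1], [0.1, 0.9]]
phi_b = [[0.8, 0.1, 0.1], [0.1, 0.1, 0.8]]
P = toy_cross_modality_score(phi_a, phi_b)
```

In this toy case, feature 0 of modality a and feature 0 of modality b both load on topic 0, so P[0][0] is larger than P[0][2].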

MultimodalAmortizedLDA.data_registry(registry_key)#

Returns the object in AnnData associated with the key in the data registry.

Parameters:

registry_key (str) – key of an object to get from self.data_registry

Return type:

ndarray | DataFrame

Returns:

The requested data.

MultimodalAmortizedLDA.deregister_manager(adata=None)#

Deregisters the AnnDataManager instance associated with adata.

If adata is None, deregisters all AnnDataManager instances in both the class and instance-specific manager stores, except for the one associated with this model instance.

MultimodalAmortizedLDA.differential_abundance(*args, **kwargs)#

Not implemented for this model class.

Available in models that inherit from VAEMixin.

Raises:

NotImplementedError

MultimodalAmortizedLDA.fit(data)#

Fit the model to the provided data.

Parameters:

data – The input data to fit the model.

classmethod MultimodalAmortizedLDA.from_data(data, modalities=None, layers=None, spatial_keys=None, table_key='table', categorical_covariate_keys=None, continuous_covariate_keys=None, **model_kwargs)#

Convenience constructor: setup + instantiation in one call.

This is equivalent to calling setup_data() followed by the model constructor. Automatically detects input type and handles all preprocessing.

Parameters:
  • data – Input data. Can be:
      - AnnData: Single or concatenated modalities
      - MuData: Multiple modalities
      - SpatialData: Spatial omics data
      - dict[str, AnnData]: Dictionary mapping modality names to AnnData

  • modalities (list[str] | None (default: None)) – Modalities to use. For MuData/SpatialData, selects a subset. For AnnData, provide a single modality name (default: "rna"). For dict, uses the dict keys if None.

  • layers (dict[str, str | None] | str | None (default: None)) – Layer specification:
      - dict: Per-modality layers, e.g. {"rna": "counts", "protein": None}
      - str: Same layer for all modalities
      - None: Use .X for all modalities

  • spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys in .obsp:
      - dict: Per-modality spatial keys
      - str: Same key for all modalities
      - None: No spatial data

  • table_key (str (default: 'table')) – For SpatialData only: which table to use.

  • **model_kwargs – Additional arguments passed to the model constructor:
      - n_topics: Number of topics (default: 20)
      - n_hidden: Hidden units (default: 128)
      - weight_mode: "equal", "universal", or "cell" (default: "cell")
      - likelihood_weight_mode: "none", "inverse_features", or "sqrt_inverse_features" (default: "none")
      - likelihood_weight_ref: "mean", "median", or "max" (default: "mean")
      - gcn_n_layers: Number of graph conv layers for spatial encoders (default: 1)
      - gcn_hidden_dims: List of hidden sizes per graph conv layer
      - likelihoods: List of likelihoods per modality ("multinomial", "gamma_poisson"/"nb", "bernoulli"; auto-inferred if not provided)

Returns:

MultimodalAmortizedLDA – Initialized model, ready for training.

Examples

>>> # From MuData with layer selection
>>> model = MultimodalAmortizedLDA.from_data(
...     mdata, modalities=["rna", "protein"], layers={"rna": "counts"}, n_topics=20
... )
>>> # From SpatialData
>>> model = MultimodalAmortizedLDA.from_data(
...     sdata, table_key="table", layers="counts", spatial_keys="spatial_connectivities", n_topics=20
... )
>>> # From dict of AnnData
>>> model = MultimodalAmortizedLDA.from_data(
...     {"rna": adata_rna, "protein": adata_protein}, layers={"rna": "counts"}, n_topics=20
... )
>>> # From single AnnData
>>> model = MultimodalAmortizedLDA.from_data(adata, modalities=["rna"], layers="counts", n_topics=20)

classmethod MultimodalAmortizedLDA.from_mudata(mdata, modality_order=None, layer_dict=None, spatial_key=None, spatial_modality_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **model_kwargs)#

High-level constructor for multimodal AmortizedLDA from MuData.

Parameters:
  • mdata (MuData) – MuData object containing multiple modalities.

  • modality_order (list[str] | None (default: None)) – Order of modalities to use. If None, uses all modalities in mdata.mod.keys().

  • layer_dict (dict[str, str] | None (default: None)) – Dictionary mapping modality names to layer names to use for each modality.

  • spatial_key (str | None (default: None)) – Single obsp key applied to all modalities (if spatial_modality_keys is not provided).

  • spatial_modality_keys (dict[str, str] | None (default: None)) – Mapping of modality -> obsp key for modality-specific spatial graphs.

  • **model_kwargs – Additional arguments passed to the model constructor. Common arguments include:
      - n_topics: Number of topics (default: 20)
      - n_hidden: Hidden units in encoders (default: 128)
      - weight_mode: "equal", "universal", or "cell" (default: "cell")
      - likelihood_weight_mode: "none", "inverse_features", or "sqrt_inverse_features" (default: "none")
      - likelihood_weight_ref: "mean", "median", or "max" (default: "mean")
      - gcn_n_layers: Number of graph conv layers for spatial encoders (default: 1)
      - gcn_hidden_dims: List of hidden sizes per graph conv layer
      - likelihoods: List of likelihoods per modality ("multinomial", "gamma_poisson"/"nb", "bernoulli"; auto-inferred if not provided)

MultimodalAmortizedLDA.get_anndata_manager(adata, required=False)#

Retrieves the AnnDataManager for a given AnnData object.

Requires that self.id has been set. Checks for an AnnDataManager specific to this model instance.

Parameters:
  • adata (AnnData | MuData) – AnnData object to find a manager instance for.

  • required (bool (default: False)) – If True, errors on missing manager. Otherwise, returns None when manager is missing.

Return type:

AnnDataManager | None

MultimodalAmortizedLDA.get_cell_entropy(adata=None, indices=None, batch_size=None, n_samples=100)#

Compute per-cell entropy of cell-topic distributions.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object with data. If None, uses the training data.

  • indices (Sequence[int] | None (default: None)) – Indices of cells to compute entropy for. If None, uses all cells.

  • batch_size (int | None (default: None)) – Batch size for computation. If None, processes all cells at once.

  • n_samples (int (default: 100)) – Number of posterior samples for Monte Carlo estimation (default: 100)

Return type:

ndarray

Returns:

np.ndarray – Per-cell entropy values, shape (n_cells,), where H(θ_n) = -Σ_k θ_{n,k} log(θ_{n,k}) for each cell n.
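The entropy formula can be checked by hand on two extreme cases. The sketch below is a plain-Python illustration of H(θ_n) for a single cell; the helper name cell_entropy is hypothetical, and the Monte Carlo averaging over posterior samples that the method performs is left out.

```python
import math


def cell_entropy(theta_n):
    """Shannon entropy of one cell-topic vector, treating 0 * log(0) as 0."""
    return sum(-p * math.log(p) for p in theta_n if p > 0.0)


uniform = [0.25, 0.25, 0.25, 0.25]   # topics used equally: maximal entropy log(K)
one_hot = [1.0, 0.0, 0.0, 0.0]       # a single topic: zero entropy

h_uniform = cell_entropy(uniform)
h_one_hot = cell_entropy(one_hot)
```

For K topics the value ranges from 0 (one topic only) to log(K) (uniform usage), which gives a scale for interpreting the returned array.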

MultimodalAmortizedLDA.get_cell_topic_dist(adata=None, indices=None, batch_size=None, n_samples=5000, return_dataframe=False)#

Get the cell-topic matrix Θ (C × K).

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object to use (default: self.adata).

  • indices (Sequence[int] | None (default: None)) – Subset of cells to use.

  • batch_size (int | None (default: None)) – Batch size for inference.

  • n_samples (int (default: 5000)) – Number of samples for Monte Carlo estimation.

  • return_dataframe (bool (default: False)) – If True, return a pd.DataFrame (cells × topics) indexed by adata.obs_names with topic_k columns instead of a raw array.

Return type:

ndarray | DataFrame

Returns:

Θ (np.ndarray | pd.DataFrame) – Cell-topic matrix of shape (C, K), where C is the number of cells and K is the number of topics.

MultimodalAmortizedLDA.get_elbo(adata=None, indices=None, batch_size=None)#

Average ELBO across batches (higher is better).

Note: Pyro’s Trace_ELBO.loss() returns -ELBO; this method negates it so the returned value is the actual ELBO (higher = better fit).

Return type:

float

MultimodalAmortizedLDA.get_entropy(adata=None, indices=None, batch_size=None, normalised=True)#

Compute mean entropy of cell-topic distributions.

Higher entropy means that each cell's topic proportions are spread more evenly across topics.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object to use (default: self.adata).

  • indices (Sequence[int] | None (default: None)) – Subset of cells to use.

  • batch_size (int | None (default: None)) – Batch size for inference.

  • normalised (bool (default: True)) – Whether to normalize cell-topic distributions before computing entropy.

Return type:

float

Returns:

float – Mean entropy across cells.

MultimodalAmortizedLDA.get_entropy_weight()#

Get the entropy regularization weight.

Return type:

float

Returns:

float – Current entropy_weight value used for regularization.

MultimodalAmortizedLDA.get_feature_topic_dist(modality=None, n_samples=5000, as_dict=False)#

Monte-Carlo estimate of E[ϕₖ,ₘ].

Parameters:
  • modality (str | int | None (default: None)) – Modality name or index. If provided, return only that modality’s topic-feature distribution; otherwise return all.

  • n_samples (int (default: 5000)) – MC samples from variational posterior.

  • as_dict (bool (default: False)) – If True, return {m: DataFrame} per modality; otherwise concatenate along features (like original single-modality API).

Return type:

dict[int, DataFrame] | DataFrame

Returns:

  • a single DataFrame concatenated along features if as_dict=False (default)

  • or a dict of DataFrames, one per modality, if as_dict=True.

In both cases, index = feature names, columns = topics.
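A Monte-Carlo estimate of this kind can be sketched as follows, assuming a logistic-normal variational posterior with diagonal covariance (the class signature lists topic_feature_prior_type='logistic_normal' as the default). The function name and the location/scale values are hypothetical; the sketch only shows the averaging of softmax-transformed Gaussian samples for one topic.

```python
import math
import random


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def mc_expected_softmax(loc, scale, n_samples=5000, seed=0):
    """Monte-Carlo estimate of E[softmax(z)] for z ~ Normal(loc, scale), diagonal."""
    rng = random.Random(seed)
    acc = [0.0] * len(loc)
    for _ in range(n_samples):
        z = [rng.gauss(m, s) for m, s in zip(loc, scale)]
        for i, p in enumerate(softmax(z)):
            acc[i] += p
    return [a / n_samples for a in acc]


# One topic over three features (toy posterior parameters).
phi_k = mc_expected_softmax(loc=[1.0, 0.0, -1.0], scale=[0.5, 0.5, 0.5])
```

Because every sample lies on the simplex, the averaged vector also sums to 1; more samples reduce the Monte-Carlo noise at the cost of run time.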

MultimodalAmortizedLDA.get_from_registry(adata, registry_key)#

Returns the object in AnnData associated with the key in the data registry.

AnnData object should be registered with the model prior to calling this function via the self._validate_anndata method.

Parameters:
  • registry_key (str) – key of object to get from the data registry.

  • adata (AnnData | MuData) – AnnData to pull data from.

Return type:

ndarray

Returns:

The requested data as a NumPy array.

MultimodalAmortizedLDA.get_last_entropy()#

Get the mean entropy from the last forward pass through the model.

Return type:

float | None

Returns:

float | None – Mean cell-topic entropy from the last forward pass, or None if not available.

MultimodalAmortizedLDA.get_last_topic_variance()#

Get the last computed mean topic variance from training.

Return type:

float | None

Returns:

float | None – Mean topic variance from the last forward pass, or None if not available.

MultimodalAmortizedLDA.get_latent_representation(adata=None, indices=None, batch_size=None, n_samples=5000, return_dataframe=False)#

Infer θₙ for all cells (or subset).

Alias for get_cell_topic_dist(). Returns the cell-topic matrix Θ (C × K) of softmax-normalized expectations as a np.ndarray by default, or a pd.DataFrame if return_dataframe=True.

Return type:

ndarray | DataFrame

MultimodalAmortizedLDA.get_learned_dispersion(modality=None, n_samples=1000)#

Get the learned or fixed dispersion parameters.

Parameters:
  • modality (str | int | None (default: None)) – Modality name or index. If None, returns dispersion for all NB modalities as a dict.

  • n_samples (int (default: 1000)) – Number of Monte Carlo samples for learned dispersion. Default: 1000.

Return type:

dict[str, ndarray] | ndarray

Returns:

dict[str, np.ndarray] | np.ndarray – If modality is None, a dict mapping modality names to dispersion arrays. If modality is specified, the dispersion array for that modality.

Notes

  • If learnable_dispersion=False, returns the fixed dispersion value

  • If learnable_dispersion=True and global_dispersion=True: returns (1,) array

  • If learnable_dispersion=True and global_dispersion=False: returns (n_features,) array

MultimodalAmortizedLDA.get_likelihood_per_modality(adata=None, indices=None, batch_size=None)#

Compute log-likelihood for each modality separately.

Higher is better.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object to use (default: self.adata).

  • indices (Sequence[int] | None (default: None)) – Subset of cells to use.

  • batch_size (int | None (default: None)) – Batch size for inference.

Return type:

dict[str, float]

Returns:

dict[str, float] – Dictionary mapping modality names to log-likelihood values.

MultimodalAmortizedLDA.get_likelihood_weights(return_format='dataframe')#

Return per-modality likelihood scaling weights used in the generative model.

Parameters:

return_format (str (default: 'dataframe')) – “dataframe” returns a single-row DataFrame (modalities as columns), “dict” returns a mapping of modality name -> weight.

Return type:

DataFrame | dict[str, float]

Returns:

pd.DataFrame or dict[str, float] – Likelihood weights for each modality.

MultimodalAmortizedLDA.get_modality_weights(adata=None, indices=None, batch_size=None, return_format='dataframe')#

Get normalized mixing weights showing how much each modality contributes to topic assignments.

The mixing weights control how much each modality’s encoder output contributes to the final cell-topic distribution (θ) in the Mixture-of-Experts architecture.

Returns weights in range [0, 1] that sum to 1 per cell (or globally for “universal” mode). Higher weight = model relies more on that modality for inferring topics.

With weight_mode='cell', the result is a per-cell distribution over modalities; in the other modes, every cell shares the same weights.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object to compute weights for. If None, uses the registered dataset.

  • indices (Sequence[int] | None (default: None)) – Indices of cells to include. If None, uses all cells.

  • batch_size (int | None (default: None)) – Batch size for inference.

  • return_format (str (default: 'dataframe')) – Output format: “dataframe” returns pd.DataFrame (cells × modalities), “dict” returns dict mapping modality names to 1D arrays.

Return type:

DataFrame | dict[str, ndarray]

Returns:

pd.DataFrame or dict[str, np.ndarray] – Normalized mixing weights for each cell and modality:
  - "universal" mode: a single row with global weights.
  - "equal" mode: uniform weights (1/n_modalities).
  - "cell" mode: per-cell learned weights.

MultimodalAmortizedLDA.get_normalized_expression(*args, **kwargs)#

Not implemented for this model class.

Available in RNA models that inherit from RNASeqMixin.

Raises:

NotImplementedError

MultimodalAmortizedLDA.get_perplexity(adata=None, indices=None, batch_size=None)#

exp( -log_likelihood / total_counts ) — lower is better.

Uses the observation log-likelihood (via poutine.trace) rather than the full ELBO, so KL terms and regularization bonuses do not inflate the result. Note that this value can become extremely large, because the negative per-count log-likelihood is exponentiated.

Return type:

float
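The formula and its tendency to blow up can be illustrated with plain floats. The sketch below assumes the log-likelihood and total count are already available as numbers; the helper name perplexity and the choice to return infinity on overflow are assumptions made for illustration.

```python
import math


def perplexity(log_likelihood, total_counts):
    """exp(-log_likelihood / total_counts); returns inf if the exponent overflows."""
    exponent = -log_likelihood / total_counts
    try:
        return math.exp(exponent)
    except OverflowError:
        # math.exp raises for exponents beyond roughly 709 in double precision.
        return math.inf


ppl_ok = perplexity(log_likelihood=-6000.0, total_counts=1000)   # exp(6.0)
ppl_big = perplexity(log_likelihood=-1.0e6, total_counts=1000)   # exp(1000) overflows
```

Comparing the exponent (the negative per-count log-likelihood) instead of the exponentiated value avoids the overflow while preserving the ranking of models.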

MultimodalAmortizedLDA.get_perplexity_per_modality(adata=None, indices=None, batch_size=None)#

Compute perplexity for each modality separately.

Lower is better. Perplexity = exp(-log_likelihood / N_tokens)

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object to use (default: self.adata).

  • indices (Sequence[int] | None (default: None)) – Subset of cells to use.

  • batch_size (int | None (default: None)) – Batch size for inference.

Return type:

dict[str, float]

Returns:

dict[str, float] – Dictionary mapping modality names to perplexity values.

MultimodalAmortizedLDA.get_setup_arg(setup_arg)#

Returns the string provided to setup of a specific setup_arg.

Return type:

attrdict

MultimodalAmortizedLDA.get_state_registry(registry_key)#

Returns the state registry for the AnnDataField registered with this instance.

Return type:

attrdict

MultimodalAmortizedLDA.get_top_features_per_topic(modality, n_features=10, return_scores=False)#

Get top N features for each topic in a specific modality.

Parameters:
  • modality (str) – Modality name (e.g., ‘rna’, ‘protein’, ‘chromatin’)

  • n_features (int (default: 10)) – Number of top features to return per topic (default: 10)

  • return_scores (bool (default: False)) – If True, return (feature_name, score) tuples. If False, return feature names only (default: False).

Return type:

dict[str, list[str]] | dict[str, list[tuple[str, float]]]

Returns:

dict[str, list[str]] or dict[str, list[tuple[str, float]]] – Dictionary mapping topic names (e.g., 'topic_0') to lists of top feature names or (feature_name, score) tuples.
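The shape of the returned dictionary can be illustrated with a small stand-alone ranking routine. This is a sketch under the assumption that features are ranked by their topic-feature score in descending order; top_features_per_topic, the feature names, and the scores below are hypothetical.

```python
def top_features_per_topic(scores, feature_names, n_features=10, return_scores=False):
    """Rank features within each topic by score (hypothetical helper).

    scores: list of per-topic score lists, one value per feature.
    """
    result = {}
    for k, row in enumerate(scores):
        ranked = sorted(zip(feature_names, row), key=lambda pair: pair[1], reverse=True)
        top = ranked[:n_features]
        result[f"topic_{k}"] = top if return_scores else [name for name, _ in top]
    return result


names = ["CD3E", "CD19", "MS4A1", "CD8A"]
scores = [
    [0.50, 0.05, 0.05, 0.40],  # topic_0
    [0.05, 0.50, 0.40, 0.05],  # topic_1
]
top = top_features_per_topic(scores, names, n_features=2)
print(top)  # {'topic_0': ['CD3E', 'CD8A'], 'topic_1': ['CD19', 'MS4A1']}
```

With return_scores=True, each list entry becomes a (feature_name, score) tuple instead of a bare name.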

MultimodalAmortizedLDA.get_topic_diversity(modality=None)#

Compute topic diversity (average pairwise cosine distance) per modality or overall.

Return type:

float
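Average pairwise cosine distance can be sketched in a few lines. The example assumes the distance is computed between the topic-feature vectors of every pair of topics and then averaged; the helper names are hypothetical, and how the method combines modalities for the "overall" case is not covered.

```python
import math
from itertools import combinations


def cosine_distance(u, v):
    """1 - cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def topic_diversity(topic_feature):
    """Mean cosine distance over all unordered pairs of topics."""
    pairs = list(combinations(topic_feature, 2))
    return sum(cosine_distance(u, v) for u, v in pairs) / len(pairs)


orthogonal = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # fully distinct topics
identical = [[0.2, 0.3, 0.5]] * 3                                # redundant topics

d_orthogonal = topic_diversity(orthogonal)
d_identical = topic_diversity(identical)
```

Values near 1 indicate topics with little feature overlap, while values near 0 indicate redundant topics.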

MultimodalAmortizedLDA.get_topic_variance(adata=None, indices=None, batch_size=None, n_samples=100)#

Compute per-topic variance of topic usage across cells.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object with data. If None, uses the training data.

  • indices (Sequence[int] | None (default: None)) – Indices of cells to compute variance for. If None, uses all cells.

  • batch_size (int | None (default: None)) – Batch size for computation. If None, processes all cells at once.

  • n_samples (int (default: 100)) – Number of posterior samples for Monte Carlo estimation (default: 100)

Return type:

ndarray

Returns:

np.ndarray – Per-topic variance values, shape (n_topics,), where Var(θ[:, k]) is the variance of topic k usage across all cells.
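The quantity Var(θ[:, k]) can be reproduced on a toy cell-topic matrix. The sketch assumes the population variance (dividing by the number of cells); whether the method uses that or the sample variance is not stated here, and topic_variance is a hypothetical stand-alone helper.

```python
def topic_variance(theta):
    """Population variance of each topic column across cells.

    theta: list of per-cell topic vectors (n_cells x n_topics).
    """
    n_cells = len(theta)
    n_topics = len(theta[0])
    variances = []
    for k in range(n_topics):
        column = [row[k] for row in theta]
        mean = sum(column) / n_cells
        variances.append(sum((x - mean) ** 2 for x in column) / n_cells)
    return variances


theta = [
    [1.0, 0.0],
    [0.0, 1.0],
    [0.5, 0.5],
    [0.5, 0.5],
]
var_k = topic_variance(theta)
print(var_k)  # [0.125, 0.125]
```

A topic whose usage is identical in every cell has zero variance, so low values flag topics that do not discriminate between cells.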

MultimodalAmortizedLDA.get_topic_variance_weight()#

Get the topic variance regularization weight.

Return type:

float

Returns:

float – Topic variance weight used during training.

MultimodalAmortizedLDA.get_var_names(legacy_mudata_format=False)#

Variable names of input data.

Return type:

dict

classmethod MultimodalAmortizedLDA.load(dir_path, adata=None, accelerator='auto', device='auto', prefix=None, backup_url=None, datamodule=None, allowed_classes_names_list=None)#

Instantiate a model from the saved output.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • adata (AnnData | MuData | None (default: None)) – AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, as AnnData is validated against the saved scvi setup dictionary. If None, will check for and load anndata saved with the model. If False, will load the model without AnnData.

  • accelerator (str (default: 'auto')) – Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances.

  • device (int | str (default: 'auto')) – The device to use. Can be set to a non-negative index (int or str) or "auto" for automatic selection based on the chosen accelerator. If set to "auto" and accelerator is not determined to be "cpu", then device will be set to the first available device.

  • prefix (str | None (default: None)) – Prefix of saved file names.

  • backup_url (str | None (default: None)) – URL to retrieve saved outputs from if not present on disk.

  • datamodule (LightningDataModule | None (default: None)) – EXPERIMENTAL A LightningDataModule instance to use for training in place of the default DataSplitter. Can only be passed in if the model was not initialized with AnnData.

  • allowed_classes_names_list (list[str] | None (default: None)) – list of allowed classes names to be loaded (besides the original class name)

Returns:

Model with loaded state dictionaries.

Examples

>>> model = MultimodalAmortizedLDA.load(save_path, adata)
>>> model.get_....

static MultimodalAmortizedLDA.load_registry(dir_path, prefix=None)#

Return the full registry saved with the model.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • prefix (str | None (default: None)) – Prefix of saved file names.

Return type:

dict

Returns:

The full registry saved with the model

MultimodalAmortizedLDA.predict(data)#

Predict using the fitted model on the provided data.

Parameters:

data – The input data for prediction.

classmethod MultimodalAmortizedLDA.register_manager(adata_manager)#

Registers an AnnDataManager instance with this model class.

Stores the AnnDataManager reference in a class-specific manager store. Intended for use in the setup_anndata() class method followed up by retrieval of the AnnDataManager via the _get_most_recent_anndata_manager() method in the model init method.

Notes

Subsequent calls to this method with an AnnDataManager instance referring to the same underlying AnnData object will overwrite the reference to previous AnnDataManager.

MultimodalAmortizedLDA.save(dir_path, prefix=None, overwrite=False, save_anndata=False, save_kwargs=None, legacy_mudata_format=False, datamodule=None, **anndata_write_kwargs)#

Save the state of the model.

Neither the trainer optimizer state nor the trainer history are saved. Model files are not expected to be reproducibly saved and loaded across versions until we reach version 1.0.

Parameters:
  • dir_path (str) – Path to a directory.

  • prefix (str | None (default: None)) – Prefix to prepend to saved file names.

  • overwrite (bool (default: False)) – Overwrite existing data or not. If False and directory already exists at dir_path, an error will be raised.

  • save_anndata (bool (default: False)) – If True, also saves the anndata

  • save_kwargs (dict | None (default: None)) – Keyword arguments passed into torch.save().

  • legacy_mudata_format (bool (default: False)) – If True, saves the model var_names in the legacy format if the model was trained with a MuData object. The legacy format is a flat array with variable names across all modalities concatenated, while the new format is a dictionary with keys corresponding to the modality names and values corresponding to the variable names for each modality.

  • datamodule (LightningDataModule | None (default: None)) – EXPERIMENTAL A LightningDataModule instance to use for training in place of the default DataSplitter. Can only be passed in if the model was not initialized with AnnData.

  • anndata_write_kwargs – Kwargs for write()
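
The difference between the two var_names layouts mentioned under legacy_mudata_format can be sketched as follows. This only illustrates the shapes described above; the helper name to_legacy_var_names, the feature names, and the assumption that modalities are concatenated in dictionary insertion order are hypothetical and do not reflect the actual on-disk format.

```python
def to_legacy_var_names(var_names_by_modality):
    """Flatten {modality: [names]} into one list, in dict insertion order."""
    flat = []
    for names in var_names_by_modality.values():
        flat.extend(names)
    return flat


# New format: one entry per modality (hypothetical feature names).
new_format = {"rna": ["GeneA", "GeneB"], "protein": ["CD3", "CD4"]}

# Legacy format: a single flat sequence across all modalities.
legacy_format = to_legacy_var_names(new_format)
```

The flat layout loses the modality boundaries, which is why the dictionary layout is easier to map back onto a MuData object.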

classmethod MultimodalAmortizedLDA.setup_adata_dict(adata_dict, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)#

Setup method for dict[str, AnnData] input.

Converts dictionary to concatenated AnnData and processes it for model training.

Parameters:
  • adata_dict (dict[str, AnnData]) – Dictionary mapping modality names to AnnData objects.

  • layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (per-modality dict, or string for all).

  • spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (per-modality dict, or string for all).

  • **kwargs – Additional arguments passed to scvi registration.

Returns:

adata_concat

The processed and registered AnnData object.

classmethod MultimodalAmortizedLDA.setup_anndata(adata, layer=None, spatial_key=None, modalities=None, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, encoder_extra_obsm_key=None, **kwargs)#

Sets up the AnnData object for this model.

A mapping will be created between data fields used by this model to their respective locations in adata. None of the data in adata are modified. Only adds fields to adata.

Parameters:
  • adata (AnnData) – AnnData object. Rows represent cells, columns represent features.

  • layer (str | None (default: None)) – If not None, uses this as the key in adata.layers for raw count data.

  • spatial_key (str | None (default: None)) – Optional key in adata.obsp pointing to a precomputed spatial graph.

  • modalities (list[str] | None (default: None)) – List of modality names (for new API). If None, defaults to ["rna"].

  • layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (for new API). Can be string or dict.

  • spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (for new API). Can be string or dict.

  • categorical_covariate_keys (list[str] | None (default: None)) – Keys in adata.obs for categorical covariates (e.g., batch, sample). These will be embedded and used for batch effect correction.

  • continuous_covariate_keys (list[str] | None (default: None)) – Keys in adata.obs for continuous covariates (e.g., age, percent_mito). These will be directly concatenated to the covariate representation.

  • encoder_extra_obsm_key (str | None (default: None)) – Optional key in adata.obsm for extra encoder input features (e.g., precomputed SGC-smoothed features). These are concatenated to the encoder input but NOT used by the decoder/likelihood. The reconstruction target remains the raw counts in .X.
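
The split described for encoder_extra_obsm_key (extra features reach the encoder only, while the likelihood still targets the raw counts) can be illustrated with plain lists. This is a conceptual sketch, not the model's code: build_encoder_input is a hypothetical helper, the values are made up, and any input normalization the model may apply is ignored.

```python
def build_encoder_input(counts_row, extra_row=None):
    """Concatenate optional extra features onto the counts for the encoder."""
    if extra_row is None:
        return list(counts_row)
    return list(counts_row) + list(extra_row)


counts = [3, 0, 5]  # one cell's raw counts (as stored in .X)
extra = [0.2, 0.7]  # e.g. precomputed smoothed features from .obsm

encoder_input = build_encoder_input(counts, extra)  # counts + extra features
reconstruction_target = counts                      # decoder never sees extra
```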

classmethod MultimodalAmortizedLDA.setup_data(data, modalities=None, layers=None, spatial_keys=None, table_key='table', categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)#

Universal setup method with automatic type detection.

Detects the type of input data (AnnData, MuData, SpatialData, or dict[str, AnnData]) and routes to the appropriate type-specific setup method.

This method follows scvi-tools conventions, performing data registration as a side effect. After calling this method, instantiate the model with the same data object.
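
The type-based routing can be sketched roughly as below. The placeholder classes stand in for the real AnnData, MuData, and SpatialData types so that the example is self-contained; pick_setup_method is a hypothetical helper, and the order of the checks is an assumption rather than the library's actual logic. Only the target method names come from this page.

```python
class AnnData:
    """Placeholder for anndata.AnnData (illustration only)."""


class MuData:
    """Placeholder for mudata.MuData (illustration only)."""


class SpatialData:
    """Placeholder for spatialdata.SpatialData (illustration only)."""


def pick_setup_method(data):
    """Return the name of the type-specific setup method for `data`."""
    if isinstance(data, dict):
        # A dict is only accepted if it maps modality names to AnnData objects.
        for key, value in data.items():
            if not isinstance(key, str) or not isinstance(value, AnnData):
                raise TypeError("Expected dict[str, AnnData].")
        return "setup_adata_dict"
    if isinstance(data, MuData):
        return "setup_mudata"
    if isinstance(data, SpatialData):
        return "setup_spatialdata"
    if isinstance(data, AnnData):
        return "setup_anndata"
    raise TypeError(f"Unsupported data type: {type(data).__name__}")
```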

Parameters:
  • data – Input data of any supported type:
      - AnnData: single-modality data
      - MuData: multi-modal data
      - SpatialData: spatial omics data
      - dict[str, AnnData]: dictionary mapping modality names to AnnData objects

  • modalities (list[str] | None (default: None)) – List of modality names to use. If None, uses all available modalities.

  • layers (dict[str, str | None] | str | None (default: None)) – Layer specifications for data extraction. Can be:
      - None: use .X for all modalities
      - str: use the same layer for all modalities (e.g., "counts")
      - dict: per-modality layer specification (e.g., {"rna": "counts", "protein": "raw"})

  • spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys in .obsp. Can be:
      - None: no spatial graphs
      - str: use the same key for all modalities
      - dict: per-modality spatial keys

  • table_key (str (default: 'table')) – For SpatialData only: key in sdata.tables to extract.

  • **kwargs – Additional arguments passed to scvi registration.
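
The three accepted forms of layers and spatial_keys (None, a single string, or a per-modality dict) can be reduced to one per-modality mapping. The sketch below shows one way to do that; normalize_per_modality is a hypothetical helper, and treating modalities missing from a dict as None is an assumption, not documented behaviour.

```python
def normalize_per_modality(spec, modalities):
    """Expand None / str / dict into a {modality: value} mapping."""
    if spec is None or isinstance(spec, str):
        # None or a single string applies uniformly to every modality.
        return {m: spec for m in modalities}
    if isinstance(spec, dict):
        # Assumption: modalities absent from the dict fall back to None.
        return {m: spec.get(m) for m in modalities}
    raise TypeError("spec must be None, a str, or a dict")


modalities = ["rna", "protein"]
same_layer = normalize_per_modality("counts", modalities)
per_modality = normalize_per_modality({"rna": "counts"}, modalities)
no_layer = normalize_per_modality(None, modalities)
```

For layers, a None value corresponds to reading .X; for spatial_keys, it corresponds to no spatial graph.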

Examples

>>> # With MuData
>>> MultimodalAmortizedLDA.setup_data(
...     mdata, modalities=["rna", "protein"], layers="counts", spatial_keys="connectivities"
... )
>>> model = MultimodalAmortizedLDA(mdata, n_topics=20)
>>> # With AnnData
>>> MultimodalAmortizedLDA.setup_data(adata, modalities=["rna"], layers="counts")
>>> model = MultimodalAmortizedLDA(adata, n_topics=20)
>>> # With dict[str, AnnData]
>>> adata_dict = {"rna": adata_rna, "protein": adata_protein}
>>> MultimodalAmortizedLDA.setup_data(adata_dict, layers={"rna": "counts", "protein": "raw"})
>>> # For dict, get the processed AnnData from the return value
>>> adata_concat = MultimodalAmortizedLDA.setup_data(adata_dict, layers="counts")
>>> model = MultimodalAmortizedLDA(adata_concat, n_topics=20)

classmethod MultimodalAmortizedLDA.setup_mudata(mdata, modality_order=None, layer_dict=None, spatial_key=None, spatial_modality_keys=None, modalities=None, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)#

Sets up MuData for multimodal AmortizedLDA.

This method stores modality metadata in mdata.uns and prepares the data for the model without concatenating features.

Parameters:
  • mdata (MuData) – MuData object containing multiple modalities.

  • modality_order (list[str] | None (default: None)) – Order of modalities to use. If None, uses all modalities in mdata.mod.keys(). (Old parameter name, prefer modalities)

  • layer_dict (dict[str, str] | None (default: None)) – Dictionary mapping modality names to layer names to use for each modality. (Old parameter name, prefer layers)

  • spatial_key (str | None (default: None)) – Single obsp key applied to all modalities (if spatial_modality_keys is not provided). (Old parameter name, prefer spatial_keys)

  • spatial_modality_keys (dict[str, str] | None (default: None)) – Mapping of modality -> obsp key for modality-specific spatial graphs. (Old parameter name, prefer spatial_keys)

  • modalities (list[str] | None (default: None)) – List of modality names to use (new parameter name, alias for modality_order).

  • layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (new parameter name). Can be:
      - None: use .X for all modalities
      - str: use the same layer for all modalities
      - dict: per-modality layer specification

  • spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (new parameter name). Can be:
      - None: no spatial graphs
      - str: use the same key for all modalities
      - dict: per-modality spatial keys

  • **kwargs – Additional arguments passed to setup_anndata.

Return type:

tuple[MuData, list[str], list[int]]

Returns:

mdata

The input MuData object with metadata stored in .uns.

modality_names

List of modality names in the order they will be processed.

feat_counts

List of feature counts per modality.
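
How the old and new parameter names might be reconciled, and how modality_names and feat_counts relate to each other, can be sketched as follows. resolve_modalities and the feature counts are hypothetical, and letting the new name (modalities) take precedence over the old one (modality_order) is an assumption.

```python
def resolve_modalities(available, modalities=None, modality_order=None):
    """Pick the modality order, preferring the new parameter name."""
    # Assumption: `modalities` wins if both names are supplied.
    chosen = modalities if modalities is not None else modality_order
    if chosen is None:
        chosen = list(available)  # default: every available modality
    missing = [m for m in chosen if m not in available]
    if missing:
        raise KeyError(f"Unknown modalities: {missing}")
    return list(chosen)


# Hypothetical number of features (columns) per modality.
n_vars_by_modality = {"rna": 2000, "protein": 30}

modality_names = resolve_modalities(n_vars_by_modality, modalities=["protein", "rna"])
feat_counts = [n_vars_by_modality[m] for m in modality_names]
```

Here feat_counts follows the order of modality_names, which is the same per-modality shape that the n_inputs_modalities constructor argument expects.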

Notes

This method uses extraction utilities for flexible layer and spatial graph handling.

classmethod MultimodalAmortizedLDA.setup_spatialdata(sdata, table_key='table', modalities=None, layers=None, spatial_keys=None, categorical_covariate_keys=None, continuous_covariate_keys=None, **kwargs)#

Setup method for SpatialData input.

Extracts the specified table from SpatialData and processes it for model training.

Parameters:
  • sdata – SpatialData object containing spatial omics data.

  • table_key (str (default: 'table')) – Key in sdata.tables to extract.

  • modalities (list[str] | None (default: None)) – List of modality names to use. If None, uses all available.

  • layers (dict[str, str | None] | str | None (default: None)) – Layer specifications (per-modality dict, or string for all).

  • spatial_keys (dict[str, str] | str | None (default: None)) – Spatial graph keys (per-modality dict, or string for all).

  • **kwargs – Additional arguments passed to scvi registration.

Returns:

adata_concat

The processed and registered AnnData object.

MultimodalAmortizedLDA.to_device(device)#

Move the model to the device.

Parameters:

device (str | int | device) – Device to move model to. Options: 'cpu' for CPU, an integer GPU index (e.g., 0), 'cuda:X' where X is the GPU index (e.g., 'cuda:0'), or a torch.device object (including XLA devices for TPU). See torch.device for more info.

Examples

>>> adata = scvi.data.synthetic_iid()
>>> model = scvi.model.SCVI(adata)
>>> model.to_device("cpu")  # moves model to CPU
>>> model.to_device("cuda:0")  # moves model to GPU 0
>>> model.to_device(0)  # also moves model to GPU 0

MultimodalAmortizedLDA.train(*args, validation_size=None, **kwargs)#

Overrides train() so that validation runs by default when a validation split is requested.

scvi’s Trainer defaults to check_val_every_n_epoch = sys.maxsize unless early stopping or checkpointing is enabled, which effectively disables the validation loop. Here we set it to 1 when a validation set is present so that elbo_val is logged every epoch.
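
The rule above can be sketched as a small decision function. This is not the model's code: resolve_check_val_every_n_epoch is a hypothetical helper, and both the test for a validation set (a non-None, positive validation_size) and the choice to respect an explicit user value are assumptions.

```python
import sys


def resolve_check_val_every_n_epoch(validation_size, user_value=None):
    """Choose how often the validation loop should run."""
    if user_value is not None:
        # Assumption: an explicit user setting is left untouched.
        return user_value
    has_validation = validation_size is not None and validation_size > 0
    if has_validation:
        return 1  # validate every epoch so that elbo_val is logged
    return sys.maxsize  # effectively disables the validation loop
```

With this rule, requesting a 10% validation split yields validation every epoch, while omitting validation_size keeps the scvi default.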

MultimodalAmortizedLDA.transfer_fields(adata, **kwargs)#

Transfer fields from a model to an AnnData object.

Return type:

AnnData

MultimodalAmortizedLDA.update_setup_method_args(setup_method_args)#

Update setup method args.

Parameters:

setup_method_args (dict) – Despite the name, this is a dict of setup-method kwargs that will be used to update the existing values in the registry of this instance.

MultimodalAmortizedLDA.view_anndata_setup(adata=None, hide_state_registries=False)#

Print summary of the setup for the initial AnnData or a given AnnData object.

Parameters:
  • adata (AnnData | MuData | None (default: None)) – AnnData object set up with setup_anndata() or transfer_fields().

  • hide_state_registries (bool (default: False)) – If True, prints a shortened summary without details of each state registry.

Return type:

None

MultimodalAmortizedLDA.view_registry(hide_state_registries=False)#

Prints summary of the registry.

Parameters:

hide_state_registries (bool (default: False)) – If True, prints a shortened summary without details of each state registry.

Return type:

None

static MultimodalAmortizedLDA.view_setup_args(dir_path, prefix=None)#

Print args used to set up a saved model.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • prefix (str | None (default: None)) – Prefix of saved file names.

Return type:

None

MultimodalAmortizedLDA.view_setup_method_args()#

Prints setup kwargs used to produce a given registry.

Parameters:

registry – Registry produced by an AnnDataManager.

Return type:

None