topomics.models.SVEM_LDA_Multi

class topomics.models.SVEM_LDA_Multi(mdata, modalities=None, n_topics=20, batch_size=512, feature_frac=1.0, alpha=0.1, device='cuda', entropy_penalty=0.0, mod_weights=None)

Stochastic Variational EM for MuData with shared topics across modalities.

Args:
mdata: MuDataType | dict[str, AnnData] | dict[str, Tensor]

Multi-modal data container.

modalities: list[str] | None

Modality names, one for each AnnData when mdata is given as a list.

n_topics: int

Number of topics to learn.

batch_size: int

Batch size for training.

feature_frac: float

Fraction of features to use for training.

alpha: float

Dirichlet prior parameter for the per-cell topic distributions.

device: str

Device to run the model on (‘cuda’ or ‘cpu’).

entropy_penalty: float

Weight of the entropy penalty used for regularization (0.0 disables it).

mod_weights: dict[str, float] | None

Per-modality weights, keyed by modality name.

Attributes table

Methods table

cell_topic_distribution([normalised])

Returns the cell-topic distribution (γ) for all cells.

check_input(mdata, modalities)

Validate and process the input data.

check_modalities_names()

Standardize and validate modality keys in data_dict.

clear_metric_cache()

Clear the cached metrics.

cross_modality_score(mod_a, mod_b, *[, ...])

Compute a SHARE-Topic-style cross-modal interaction matrix P_{a,b}.

diagnostics([loader, mod_weights, inner_iters])

Computes corpus log-likelihood, token-level perplexity and mean entropy of θ̂_c for the current variational parameters.

fit([n_epochs, inner_iters, kappa, tau0, ...])

Fits the model to the data using stochastic variational EM.

get_cell_topic_dist()

Get the cell-topic matrix Θ (C × K).

get_entropy([normalised])

Compute mean entropy of cell-topic distributions.

get_feature_topic_dist(modality)

Get the feature-topic matrix Φ (K × G).

get_likelihood_per_modality(**kwargs)

Compute log-likelihood for each modality separately.

get_modality_weights(**kwargs)

Get normalized mixing weights showing how much each modality contributes to topic assignments.

get_perplexity(**kwargs)

Compute perplexity (reconstruction quality).

get_perplexity_per_modality(**kwargs)

Compute perplexity for each modality separately.

get_top_features_per_topic(modality[, ...])

Get top N features for each topic in a specific modality.

get_topic_diversity([modality])

Compute topic diversity as average pairwise cosine distance.

predict(data)

Predict using the fitted model on the provided data.

topic_by_feature(mod)

Returns the topic-by-feature matrix for the specified modality.

Attributes

SVEM_LDA_Multi.spatial: bool = False

Methods

SVEM_LDA_Multi.cell_topic_distribution(normalised=True)

Returns the cell-topic distribution (γ) for all cells.

If normalised is True (the default), each cell's row is rescaled to sum to 1. Otherwise, the raw γ values are returned.
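Below is a minimal sketch of the row normalisation. It assumes γ is held as a non-negative NumPy array of shape (C, K). For Dirichlet variational parameters, the normalised rows are the posterior means of each cell's topic proportions. The helper normalise_gamma is hypothetical and is not part of topomics.

```python
import numpy as np


def normalise_gamma(gamma: np.ndarray) -> np.ndarray:
    """Rescale each row of a (C, K) gamma matrix so that it sums to 1.

    Hypothetical helper illustrating normalised=True; not the topomics code.
    """
    gamma = np.asarray(gamma, dtype=float)
    row_sums = gamma.sum(axis=1, keepdims=True)
    # Each row now holds one cell's estimated topic proportions.
    return gamma / row_sums


gamma = np.array([[2.0, 1.0, 1.0],
                  [0.5, 0.5, 3.0]])
theta = normalise_gamma(gamma)
print(theta.sum(axis=1))  # every row sums to 1
```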

SVEM_LDA_Multi.check_input(mdata, modalities)

Validate and process the input data.

Checks that the data are AnnData or MuData objects and that the modalities are correctly specified.

SVEM_LDA_Multi.check_modalities_names()

Standardize and validate modality keys in data_dict.

Maps various synonyms to ‘rna’, ‘protein’, or ‘chromatin’, and rebuilds data_dict with standardized keys.
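Below is a sketch of how synonym standardisation of this kind can work. The synonym table is hypothetical, because the mapping topomics actually uses is not documented here. Raising an error on unknown or duplicate names is also an assumption about what "validate" involves. The helper standardise_keys is illustrative only.

```python
# Hypothetical synonym table; the real topomics mapping is not documented here.
SYNONYMS = {
    "rna": "rna", "gex": "rna", "gene_expression": "rna",
    "protein": "protein", "prot": "protein", "adt": "protein",
    "chromatin": "chromatin", "atac": "chromatin", "peaks": "chromatin",
}


def standardise_keys(data_dict: dict) -> dict:
    """Return a new dict whose keys are canonical modality names (hypothetical helper)."""
    out = {}
    for key, value in data_dict.items():
        canonical = SYNONYMS.get(key.lower())
        if canonical is None:
            # Assumed behaviour: unknown names are rejected rather than passed through.
            raise ValueError(f"Unrecognised modality name: {key!r}")
        if canonical in out:
            raise ValueError(f"Two keys map to the same modality: {canonical!r}")
        out[canonical] = value
    return out


print(standardise_keys({"GEX": 1, "atac": 2}))
```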

SVEM_LDA_Multi.clear_metric_cache()

Clear the cached metrics.

Call this method after retraining the model to ensure metrics are recomputed with the updated parameters.

SVEM_LDA_Multi.cross_modality_score(mod_a, mod_b, *, normalise=True, return_df=True)

Compute a SHARE-Topic-style cross-modal interaction matrix P_{a,b}.

Parameters:
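  • mod_a (str) – Name of the first modality; its features index the rows of P.

  • mod_b (str) – Name of the second modality; its features index the columns of P.

  • normalise (bool (default: True))

  • return_df (bool (default: True)) – If True, return a DataFrame; otherwise return an ndarray.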

Return type:

ndarray | DataFrame

Returns:

P : shape (n_feat_a, n_feat_b) – interaction score between every feature of mod_a and every feature of mod_b
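A SHARE-Topic-style score links the features of two modalities through the topics they share. This docstring does not state the exact formula. The sketch below assumes one plausible form: the product of the two feature-topic distributions, weighted by the average topic prevalence π over cells, P[g, h] = Σ_k π_k Φ_a[k, g] Φ_b[k, h]. The helper cross_modality_matrix and this formula are assumptions, not the topomics implementation. Under this form, P sums to 1 and can be read as a joint distribution over feature pairs.

```python
import numpy as np


def cross_modality_matrix(theta: np.ndarray, phi_a: np.ndarray, phi_b: np.ndarray) -> np.ndarray:
    """Hypothetical topic-mediated interaction matrix P_{a,b}.

    theta: (C, K) cell-topic proportions, rows sum to 1.
    phi_a: (K, G_a) and phi_b: (K, G_b) feature-topic matrices, rows sum to 1.
    Assumed form: P[g, h] = sum_k pi_k * phi_a[k, g] * phi_b[k, h].
    """
    pi = theta.mean(axis=0)                 # (K,) average topic prevalence over cells
    return phi_a.T @ (pi[:, None] * phi_b)  # (G_a, K) @ (K, G_b) -> (G_a, G_b)


rng = np.random.default_rng(0)
theta = rng.dirichlet(np.ones(3), size=5)   # 5 cells, 3 topics
phi_a = rng.dirichlet(np.ones(4), size=3)   # 3 topics over 4 features of modality a
phi_b = rng.dirichlet(np.ones(6), size=3)   # 3 topics over 6 features of modality b
P = cross_modality_matrix(theta, phi_a, phi_b)
print(P.shape, P.sum())
```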

SVEM_LDA_Multi.diagnostics(loader=None, mod_weights=None, inner_iters=3)

Computes corpus log-likelihood, token-level perplexity and mean entropy of θ̂_c for the current variational parameters.

Parameters:
  • loader (DataLoader | None (default: None)) – If None, reuse self.loader (full data, no shuffle).

  • mod_weights (dict[str, float] | None (default: None))

  • inner_iters (int (default: 3)) – Number of inner iterations used to estimate γ for held-out cells.

Return type:

dict[str, float]

Returns:

A dict with keys ‘log_lik’ (corpus log-likelihood), ‘perplexity’ (token-level perplexity) and ‘entropy’ (mean entropy of θ̂_c).

SVEM_LDA_Multi.fit(n_epochs=100, inner_iters=2, kappa=0.6, tau0=1024.0, lr_mult=1.0, verbose=True)

Fits the model to the data using stochastic variational EM.

Parameters:
  • n_epochs (int (default: 100)) – Number of epochs to train.

  • inner_iters (int (default: 2)) – Number of inner iterations to re-estimate γ.

  • kappa (float (default: 0.6)) – Learning rate decay parameter.

  • tau0 (float (default: 1024.0)) – Learning rate offset.

  • lr_mult (float (default: 1.0)) – Learning rate multiplier.

  • verbose (bool (default: True)) – Whether to print progress and metrics.
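kappa and tau0 match the parameters of the usual stochastic variational inference step size, ρ_t = (τ0 + t)^(-κ). With κ in (0.5, 1], this schedule satisfies the Robbins-Monro conditions. The sketch below assumes fit uses that schedule. It also assumes lr_mult scales ρ_t and that the result is clipped to 1; neither assumption is confirmed by this docstring. step_size and blend are hypothetical helpers.

```python
def step_size(t: int, kappa: float = 0.6, tau0: float = 1024.0, lr_mult: float = 1.0) -> float:
    """Assumed SVI step size: rho_t = lr_mult * (tau0 + t) ** (-kappa), clipped to at most 1.

    tau0 must be positive; larger tau0 damps the earliest updates,
    larger kappa makes the step size decay faster.
    """
    return min(1.0, lr_mult * (tau0 + t) ** (-kappa))


def blend(global_param, batch_estimate, rho: float):
    # Stochastic update: move the global variational parameter a fraction rho
    # of the way toward the estimate computed from the current mini-batch.
    return (1.0 - rho) * global_param + rho * batch_estimate


print([round(step_size(t), 5) for t in (0, 1_000, 10_000)])
```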

SVEM_LDA_Multi.get_cell_topic_dist()

Get the cell-topic matrix Θ (C × K).

Return type:

ndarray

Returns:

Θ (ndarray) – Cell-topic matrix of shape (C, K), where C is the number of cells and K is the number of topics.

SVEM_LDA_Multi.get_entropy(normalised=True)

Compute mean entropy of cell-topic distributions.

Higher entropy means that each cell's probability mass is spread more evenly across topics. The value therefore measures the uncertainty in topic assignments per cell.

Parameters:

normalised (bool (default: True)) – Whether to normalize cell-topic distributions before computing entropy. If True, ensures distributions sum to 1 (default: True).

Return type:

float

Returns:

float Mean entropy across all cells
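Below is a sketch of the metric as described. It computes the Shannon entropy of each row of a (C, K) cell-topic matrix and averages over cells. The natural logarithm and the small eps guarding log(0) are assumptions; topomics may use a different log base or handle zeros differently. mean_entropy is a hypothetical helper.

```python
import numpy as np


def mean_entropy(theta: np.ndarray, normalised: bool = True, eps: float = 1e-12) -> float:
    """Mean Shannon entropy (natural log) of the rows of a (C, K) cell-topic matrix."""
    theta = np.asarray(theta, dtype=float)
    if normalised:
        # Mirror normalised=True: make every row sum to 1 first.
        theta = theta / theta.sum(axis=1, keepdims=True)
    # eps avoids log(0); entries equal to 0 then contribute 0 to the sum.
    per_cell = -(theta * np.log(theta + eps)).sum(axis=1)
    return float(per_cell.mean())


# One maximally uncertain cell (uniform over 4 topics) and one certain cell.
theta = np.array([[0.25, 0.25, 0.25, 0.25],
                  [1.0, 0.0, 0.0, 0.0]])
print(mean_entropy(theta))
```

The uniform row reaches the maximum possible entropy, log K. The one-hot row contributes essentially 0.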

SVEM_LDA_Multi.get_feature_topic_dist(modality)

Get the feature-topic matrix Φ (K × G).

Parameters:

modality (str) – The name of the modality for which to retrieve the feature-topic matrix.

Return type:

ndarray | DataFrame

Returns:

Φ : np.ndarray or pd.DataFrame Feature-topic matrix, where K is the number of topics and G is the number of features. If the modality has feature names, returns a DataFrame with those names.

SVEM_LDA_Multi.get_likelihood_per_modality(**kwargs)

Compute log-likelihood for each modality separately.

Higher is better.

Return type:

dict[str, float]

Returns:

dict[str, float] Dictionary mapping modality names to log-likelihood values

SVEM_LDA_Multi.get_modality_weights(**kwargs)

Get normalized mixing weights showing how much each modality contributes to topic assignments.

Only applicable to multimodal models with a mixture-of-experts or similar architecture. Returns weights in the range [0, 1] that sum to 1 per cell. A higher weight means the model relies more on that modality when inferring topics.

Return type:

DataFrame | dict[str, ndarray]

Returns:

pd.DataFrame or dict[str, np.ndarray] Normalized mixing weights for each cell and modality, given either as a DataFrame of shape cells × modalities or as a dict mapping modality name → weights array.

SVEM_LDA_Multi.get_perplexity(**kwargs)

Compute perplexity (reconstruction quality).

Lower is better. Perplexity = exp(-log_likelihood / N_tokens)
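The formula can be checked on a simple case. The sketch below assumes that N_tokens is the total count over all cells, which is the usual convention for count data. perplexity is a hypothetical helper and does not reflect how topomics accumulates the log-likelihood.

```python
import math


def perplexity(log_likelihood: float, n_tokens: float) -> float:
    """Perplexity = exp(-log_likelihood / n_tokens)."""
    if n_tokens <= 0:
        raise ValueError("n_tokens must be positive")
    return math.exp(-log_likelihood / n_tokens)


# A model that spreads probability uniformly over 50 features assigns
# log(1/50) to every token, so its perplexity is (up to rounding) 50.
n_tokens = 1_000
ll = n_tokens * math.log(1 / 50)
print(perplexity(ll, n_tokens))
```

Perplexity can be read as the effective number of features the model is choosing among for each token, which is why lower values are better.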

Return type:

float

Returns:

float Perplexity score

SVEM_LDA_Multi.get_perplexity_per_modality(**kwargs)

Compute perplexity for each modality separately.

Lower is better. Perplexity = exp(-log_likelihood / N_tokens)

Return type:

dict[str, float]

Returns:

dict[str, float] Dictionary mapping modality names to perplexity values

SVEM_LDA_Multi.get_top_features_per_topic(modality, n_features=10, return_scores=False)

Get top N features for each topic in a specific modality.

Parameters:
  • modality (str) – Modality name (e.g., ‘rna’, ‘protein’, ‘chromatin’)

  • n_features (int (default: 10)) – Number of top features to return per topic (default: 10)

  • return_scores (bool (default: False)) – If True, return (feature_name, score) tuples. If False, return feature names only (default: False).

Return type:

dict[str, list[str]] | dict[str, list[tuple[str, float]]]

Returns:

dict[str, list[str]] or dict[str, list[tuple[str, float]]] Dictionary mapping topic names (e.g., ‘topic_0’) to lists of top feature names or (feature_name, score) tuples.
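Below is a sketch of the ranking step. It assumes that the scores are the entries of a (K, G) feature-topic matrix Φ and that features are sorted in descending order within each topic. The helper top_features_per_topic and the feature names are hypothetical. Only the ‘topic_k’ key format follows the description above.

```python
import numpy as np


def top_features_per_topic(phi, feature_names, n_features=10, return_scores=False):
    """Rank features within each topic (row) of a (K, G) feature-topic matrix."""
    phi = np.asarray(phi, dtype=float)
    result = {}
    for k, row in enumerate(phi):
        # argsort sorts ascending, so reverse it and keep the first n_features indices.
        idx = np.argsort(row)[::-1][:n_features]
        if return_scores:
            result[f"topic_{k}"] = [(feature_names[i], float(row[i])) for i in idx]
        else:
            result[f"topic_{k}"] = [feature_names[i] for i in idx]
    return result


phi = np.array([[0.1, 0.6, 0.3],
                [0.5, 0.2, 0.3]])
names = ["g0", "g1", "g2"]
print(top_features_per_topic(phi, names, n_features=2))
```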

SVEM_LDA_Multi.get_topic_diversity(modality=None)

Compute topic diversity as average pairwise cosine distance.

Higher values indicate more distinct topics. This metric measures how different the topic-feature distributions are from each other.

Parameters:

modality (str | None (default: None)) – If provided, compute diversity for this specific modality’s feature-topic distribution. If None, compute diversity averaged across all modalities (default: None).

Return type:

float

Returns:

float Average pairwise cosine distance between topic distributions (0-1). Higher = more diverse/distinct topics.
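Below is a sketch of the metric for a single modality. It averages the cosine distance over each unordered pair of topics (rows of Φ). Because Φ is non-negative, every cosine distance lies in [0, 1], which matches the stated range. The helper topic_diversity is hypothetical. How topomics averages across modalities when modality is None may differ.

```python
import numpy as np


def topic_diversity(phi: np.ndarray) -> float:
    """Average pairwise cosine distance between the rows (topics) of a (K, G) matrix, K >= 2."""
    phi = np.asarray(phi, dtype=float)
    unit = phi / np.linalg.norm(phi, axis=1, keepdims=True)
    sim = unit @ unit.T                      # (K, K) cosine similarities
    n_topics = phi.shape[0]
    iu = np.triu_indices(n_topics, k=1)      # each unordered pair once, diagonal excluded
    return float(np.mean(1.0 - sim[iu]))


print(topic_diversity(np.eye(3)))        # non-overlapping topics: maximal diversity
print(topic_diversity(np.ones((3, 4))))  # identical topics: no diversity
```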

SVEM_LDA_Multi.predict(data)

Predict using the fitted model on the provided data.

Parameters:

data – The input data for prediction.

SVEM_LDA_Multi.topic_by_feature(mod)

Returns the topic-by-feature matrix for the specified modality.

Parameters:

mod (str) – Name of the modality.

Return type:

Tensor