topomics.models.SVEM_LDA_Multi
- class topomics.models.SVEM_LDA_Multi(mdata, modalities=None, n_topics=20, batch_size=512, feature_frac=1.0, alpha=0.1, device='cuda', entropy_penalty=0.0, mod_weights=None)
Stochastic Variational EM for MuData with shared topics across modalities.
- Args:
- mdata: MuDataType | dict[str, AnnData] | dict[str, Tensor]
Multi-modal data container.
- modalities: list[str] | None
Names corresponding to each AnnData in a list input.
- n_topics: int
Number of topics to learn.
- batch_size: int
Batch size for training.
- feature_frac: float
Fraction of features to use for training.
- alpha: float
Dirichlet prior parameter for topic distributions.
- device: str
Device to run the model on (‘cuda’ or ‘cpu’).
- entropy_penalty: float
Entropy penalty for regularization.
- mod_weights: dict[str, float] | None
Weights for each modality.
Attributes table
Methods table
| Method | Description |
| --- | --- |
| cell_topic_distribution | Returns the cell-topic distribution (γ) for all cells. |
| check_input | Validate and process the input data. |
| check_modalities_names | Standardize and validate modality keys in data_dict. |
| clear_metric_cache | Clear the cached metrics. |
| cross_modality_score | Compute SHARE-Topic–style cross-modal interaction matrix P_{a,b}. |
| diagnostics | Computes corpus log-likelihood, token-level perplexity and mean entropy of θ̂_c for the current variational parameters. |
| fit | Fits the model to the data using stochastic variational EM. |
| get_cell_topic_dist | Get the cell-topic matrix Θ (C × K). |
| get_entropy | Compute mean entropy of cell-topic distributions. |
| get_feature_topic_dist | Get the feature-topic matrix Φ (K × G). |
| get_likelihood_per_modality | Compute log-likelihood for each modality separately. |
| get_modality_weights | Get normalized mixing weights showing how much each modality contributes to topic assignments. |
| get_perplexity | Compute perplexity (reconstruction quality). |
| get_perplexity_per_modality | Compute perplexity for each modality separately. |
| get_top_features_per_topic | Get top N features for each topic in a specific modality. |
| get_topic_diversity | Compute topic diversity as average pairwise cosine distance. |
| predict | Predict using the fitted model on the provided data. |
|  | Returns the topic-by-feature matrix for the specified modality. |
Attributes
Methods
- SVEM_LDA_Multi.cell_topic_distribution(normalised=True)
Returns the cell-topic distribution (γ) for all cells.
If normalised is True, returns the normalised distribution.
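As a rough illustration of what normalising γ could involve, the sketch below assumes γ holds per-cell Dirichlet parameters and divides each row by its sum, which gives the Dirichlet mean. The helper `normalise_gamma` and the example values are hypothetical and not part of topomics; the library may normalise differently.

```python
import numpy as np

def normalise_gamma(gamma):
    """Row-normalise variational parameters so each cell's row sums to 1.

    Assumption: gamma has shape (n_cells, n_topics) with positive entries.
    """
    gamma = np.asarray(gamma, dtype=float)
    return gamma / gamma.sum(axis=1, keepdims=True)

# Invented gamma for 2 cells and 3 topics, for illustration only.
gamma = np.array([[2.0, 1.0, 1.0],
                  [0.5, 0.5, 4.0]])
theta = normalise_gamma(gamma)  # each row is now a probability vector
```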
- SVEM_LDA_Multi.check_input(mdata, modalities)
Validate and process the input data.
Checks that the data are AnnData or MuData objects and that the modalities are correctly specified.
- SVEM_LDA_Multi.check_modalities_names()
Standardize and validate modality keys in data_dict.
Maps various synonyms to ‘rna’, ‘protein’, or ‘chromatin’, and rebuilds data_dict with standardized keys.
- SVEM_LDA_Multi.clear_metric_cache()
Clear the cached metrics.
Call this method after retraining the model to ensure metrics are recomputed with the updated parameters.
- SVEM_LDA_Multi.cross_modality_score(mod_a, mod_b, *, normalise=True, return_df=True)
Compute SHARE-Topic–style cross-modal interaction matrix P_{a,b}.
- Returns:
P : shape (n_feat_a, n_feat_b) – interaction score between every feature of mod_a and every feature of mod_b.
- SVEM_LDA_Multi.diagnostics(loader=None, mod_weights=None, inner_iters=3)
Computes corpus log-likelihood, token-level perplexity and mean entropy of θ̂_c for the current variational parameters.
- Returns:
dict(log_lik=…, perplexity=…, entropy=…)
- SVEM_LDA_Multi.fit(n_epochs=100, inner_iters=2, kappa=0.6, tau0=1024.0, lr_mult=1.0, verbose=True)
Fits the model to the data using stochastic variational EM.
- Parameters:
n_epochs (int, default: 100) – Number of epochs to train.
inner_iters (int, default: 2) – Number of inner iterations to re-estimate γ.
kappa (float, default: 0.6) – Learning rate decay parameter.
tau0 (float, default: 1024.0) – Learning rate offset.
lr_mult (float, default: 1.0) – Learning rate multiplier.
verbose (bool, default: True) – Whether to print progress and metrics.
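The names kappa (decay) and tau0 (offset) match the step-size schedule commonly used in stochastic variational inference for LDA, rho_t = (tau0 + t)^(-kappa). The sketch below assumes that form and assumes lr_mult scales it linearly; neither is confirmed by this page, and `step_size` is a hypothetical helper rather than a topomics function.

```python
def step_size(t, kappa=0.6, tau0=1024.0, lr_mult=1.0):
    """Assumed schedule: lr_mult * (tau0 + t) ** (-kappa).

    t is the update counter (0-based). A larger tau0 damps early updates;
    a larger kappa makes the step size decay faster.
    """
    if t < 0:
        raise ValueError("t must be non-negative")
    return lr_mult * (tau0 + t) ** (-kappa)

# Step sizes at a few update counts, using the defaults of fit().
rates = [step_size(t) for t in (0, 100, 10_000)]
```

With the defaults, the first step size is roughly 1/64 because 1024 ** 0.6 equals 64, and later steps shrink monotonically.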
- SVEM_LDA_Multi.get_cell_topic_dist()
Get the cell-topic matrix Θ (C × K).
- Returns:
Θ (ndarray) – Cell-topic matrix, where C is the number of cells and K is the number of topics.
- SVEM_LDA_Multi.get_entropy(normalised=True)
Compute mean entropy of cell-topic distributions.
Higher entropy means each cell's topic proportions are spread more evenly across topics. This measures the uncertainty in topic assignments per cell.
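A minimal sketch of this quantity, assuming Θ is a (C × K) array of row-stochastic topic proportions and that normalised=True divides each cell's Shannon entropy by log K so values fall in [0, 1]. That reading of `normalised` is an assumption, and `mean_entropy` is a hypothetical helper, not the library's implementation.

```python
import numpy as np

def mean_entropy(theta, normalised=True):
    """Mean Shannon entropy of the rows of theta (requires K >= 2)."""
    theta = np.asarray(theta, dtype=float)
    # Clip inside the log only, so zero proportions contribute exactly 0.
    safe = np.clip(theta, 1e-12, 1.0)
    h = -(theta * np.log(safe)).sum(axis=1)
    if normalised:
        h = h / np.log(theta.shape[1])  # assumed scaling by the maximum entropy
    return float(h.mean())

uniform = np.full((4, 3), 1.0 / 3.0)  # every cell spread evenly over topics
one_hot = np.eye(3)                   # every cell assigned to a single topic
h_uniform = mean_entropy(uniform)
h_one_hot = mean_entropy(one_hot)
```

Under these assumptions, evenly mixed cells score at the top of the range and cells dominated by a single topic score at the bottom.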
- SVEM_LDA_Multi.get_feature_topic_dist(modality)
Get the feature-topic matrix Φ (K × G).
- Parameters:
modality (str) – The name of the modality for which to retrieve the feature-topic matrix.
- Returns:
Φ (np.ndarray or pd.DataFrame) – Feature-topic matrix, where K is the number of topics and G is the number of features. If the modality has feature names, returns a DataFrame with those names.
- SVEM_LDA_Multi.get_likelihood_per_modality(**kwargs)
Compute log-likelihood for each modality separately.
Higher is better.
- SVEM_LDA_Multi.get_modality_weights(**kwargs)
Get normalized mixing weights showing how much each modality contributes to topic assignments.
Only applicable for multimodal models with mixture-of-experts or similar architectures. Returns weights in the range [0, 1] that sum to 1 per cell. A higher weight means the model relies more on that modality for inferring topics.
- SVEM_LDA_Multi.get_perplexity(**kwargs)
Compute perplexity (reconstruction quality).
Lower is better. Perplexity = exp(-log_likelihood / N_tokens)
- Returns:
float – Perplexity score.
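The formula above can be checked with a few lines of plain Python. The function name `perplexity` and the numbers are illustrative assumptions; only the formula comes from this page.

```python
import math

def perplexity(log_likelihood, n_tokens):
    """Perplexity = exp(-log_likelihood / N_tokens)."""
    if n_tokens <= 0:
        raise ValueError("n_tokens must be positive")
    return math.exp(-log_likelihood / n_tokens)

# If every one of 1000 tokens had probability 1/50, the corpus
# log-likelihood would be 1000 * log(1/50).
ppl = perplexity(1000 * math.log(1.0 / 50.0), 1000)
```

In this toy case the perplexity recovers the effective number of equally likely choices per token, which is why lower values indicate a better fit.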
- SVEM_LDA_Multi.get_perplexity_per_modality(**kwargs)
Compute perplexity for each modality separately.
Lower is better. Perplexity = exp(-log_likelihood / N_tokens)
- SVEM_LDA_Multi.get_top_features_per_topic(modality, n_features=10, return_scores=False)
Get top N features for each topic in a specific modality.
- Returns:
dict[str, list[str]] or dict[str, list[tuple[str, float]]] – Dictionary mapping topic names (e.g., ‘topic_0’) to lists of top feature names or (feature_name, score) tuples.
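To make the return shape concrete, the sketch below ranks features within each row of a (K × G) feature-topic matrix and builds the same kind of dictionary. It assumes ranking by raw Φ values, which may differ from the scoring the method uses; `top_features`, the matrix, and the feature names are invented for illustration.

```python
import numpy as np

def top_features(phi, feature_names, n_features=10, return_scores=False):
    """Map 'topic_k' to the n_features highest-weighted features of row k."""
    phi = np.asarray(phi, dtype=float)
    result = {}
    for k, row in enumerate(phi):
        # Indices of the largest weights, in descending order.
        idx = np.argsort(row)[::-1][:n_features]
        if return_scores:
            result[f"topic_{k}"] = [(feature_names[i], float(row[i])) for i in idx]
        else:
            result[f"topic_{k}"] = [feature_names[i] for i in idx]
    return result

phi = np.array([[0.7, 0.2, 0.1],
                [0.1, 0.3, 0.6]])
names = ["gene_a", "gene_b", "gene_c"]
top = top_features(phi, names, n_features=2)
scored = top_features(phi, names, n_features=1, return_scores=True)
```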
- SVEM_LDA_Multi.get_topic_diversity(modality=None)
Compute topic diversity as average pairwise cosine distance.
Higher values indicate more distinct topics. This metric measures how different the topic-feature distributions are from each other.
- Parameters:
modality (str | None, default: None) – If provided, compute diversity for this specific modality’s feature-topic distribution. If None, compute diversity averaged across all modalities.
- Returns:
float – Average pairwise cosine distance between topic distributions (0-1). Higher = more diverse/distinct topics.
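A minimal sketch of average pairwise cosine distance for a single modality, assuming a non-negative (K × G) matrix with K ≥ 2 and no all-zero rows. The helper `topic_diversity` is hypothetical and does not reproduce the library's averaging across modalities.

```python
import numpy as np

def topic_diversity(phi):
    """Mean of (1 - cosine similarity) over all distinct topic pairs."""
    phi = np.asarray(phi, dtype=float)
    # Scale each topic row to unit length so dot products are cosines.
    unit = phi / np.linalg.norm(phi, axis=1, keepdims=True)
    cos_sim = unit @ unit.T
    # Use the strict upper triangle so each pair is counted once.
    pairs = np.triu_indices(phi.shape[0], k=1)
    return float(np.mean(1.0 - cos_sim[pairs]))

distinct = topic_diversity(np.eye(3))            # topics share no features
identical = topic_diversity(np.ones((3, 4)))     # all topics are the same
```

For non-negative rows the cosine similarity lies in [0, 1], so the distance does too, which matches the stated 0-1 range.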
- SVEM_LDA_Multi.predict(data)
Predict using the fitted model on the provided data.
- Parameters:
data – The input data for prediction.