topomics.models.ShareTopic_LDA_Multi

class topomics.models.ShareTopic_LDA_Multi(mdata, *, modalities=None, n_topics=20, alpha=0.1, gamma=None, tau=None, beta=0.1, protein_likelihood='nb', protein_gamma=None, protein_tau=None, protein_beta=0.1, device=None, smart_init=True, smart_init_mod='rna')

Multimodal LDA fitted with a SHARE-Topic-style Gibbs sampler that uses a sparse representation of ATAC (chromatin) data.

Parameters:
  • mdata (MuDataType | dict[str, AnnData] | list[AnnData] | AnnData) – Input data: a MuData object, a dict mapping modality names to AnnData, a list of AnnData, or a single AnnData.

  • modalities (list[str] | None (default: None)) – Modality names to assign when mdata is a list of AnnData or a single AnnData.

  • n_topics (int (default: 20)) – Number of latent topics.

  • alpha (float | list[float] (default: 0.1)) – Dirichlet prior concentration for cell-topic proportions.

  • gamma (float | None (default: None)) – Hyperparameter of the Gamma prior on the RNA rate parameters; together with tau it sets the prior's shape and rate. None = empirical Bayes (estimated from data moments).

  • tau (float | None (default: None)) – Hyperparameter of the Gamma prior on the RNA rate parameters; together with gamma it sets the prior's shape and rate. None = empirical Bayes (estimated from data moments).

  • beta (float (default: 0.1)) – Dirichlet prior concentration for chromatin region proportions.

  • protein_likelihood (str (default: 'nb')) – "nb" (Gamma-Poisson / Negative Binomial) or "multinomial" (Dirichlet-Multinomial) for the protein modality.

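  • protein_gamma (float | None (default: None)) – Hyperparameter of the Gamma prior on the protein NB rates; together with protein_tau it sets the prior's shape and rate. None = empirical Bayes.

  • protein_tau (float | None (default: None)) – Hyperparameter of the Gamma prior on the protein NB rates; together with protein_gamma it sets the prior's shape and rate. None = empirical Bayes.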

  • protein_beta (float (default: 0.1)) – Dirichlet prior concentration for protein multinomial.

  • device (str | None (default: None)) – "cpu", "cuda", or None (auto-detect).

  • smart_init (bool (default: True)) – Seed parameters via NMF instead of sampling from the prior.

  • smart_init_mod (str (default: 'rna')) – Modality used for NMF initialisation.
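The empirical Bayes option for gamma and tau (and for protein_gamma and protein_tau) is described only as estimated from data moments. A common way to do this is a method-of-moments fit of a Gamma distribution, matching the sample mean m and variance v with shape = m²/v and rate = m/v. The sketch below assumes that gamma acts as the shape and tau as the rate, and that a sample of positive rate-like values is available; gamma_moment_match is a hypothetical helper, not part of topomics.

```python
import numpy as np


def gamma_moment_match(values):
    """Hypothetical sketch: fit Gamma(shape, rate) by matching mean and variance.

    Assumes `values` are positive rate-like quantities; which data topomics
    actually uses for its empirical Bayes estimate is not documented here.
    """
    values = np.asarray(values, dtype=float)
    m = values.mean()
    v = values.var()
    if v <= 0:
        raise ValueError("values must have positive variance")
    shape = m**2 / v  # from Gamma mean = shape / rate
    rate = m / v      # from Gamma variance = shape / rate**2
    return float(shape), float(rate)
```

For [1, 2, 3, 4, 5] (mean 3, population variance 2) this gives shape 4.5 and rate 1.5, which reproduce mean 4.5 / 1.5 = 3 and variance 4.5 / 1.5² = 2.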

Attributes table

Methods table

check_input(mdata, modalities)

Validate and process the input data.

check_modalities_names()

Standardize and validate modality keys in data_dict.

clear_metric_cache()

Clear the cached metrics.

cross_modality_score(mod_a, mod_b, *[, ...])

Compute a SHARE-Topic-style cross-modal interaction matrix P_{a,b}.

fit(*[, batch_size, n_samples, thin, ...])

Run the Gibbs sampler.

generate(n_cells[, theta])

Sample synthetic data from the fitted model.

get_cell_topic_dist([normalised])

Posterior mean cell-topic proportions (C × K).

get_entropy([normalised])

Compute mean entropy of cell-topic distributions.

get_feature_topic_dist(modality)

Posterior mean feature-topic matrix (K × Gm).

get_likelihood_per_modality(**kwargs)

Compute log-likelihood for each modality separately.

get_modality_weights(**kwargs)

Not applicable to the Gibbs sampler; returns equal weights.

get_perplexity(**kwargs)

Compute the overall perplexity, exp(-LL / total_counts).

get_perplexity_per_modality(**kwargs)

Compute perplexity for each modality separately.

get_top_features_per_topic(modality[, ...])

Get top N features for each topic in a specific modality.

get_topic_diversity([modality])

Compute topic diversity as average pairwise cosine distance.

load(path, mdata, *[, modalities, device])

Load a saved model.

predict(data)

Predict using the fitted model on the provided data.

save(path)

Save fitted model to a .pt file.

Attributes

ShareTopic_LDA_Multi.spatial: bool = False

Methods

ShareTopic_LDA_Multi.check_input(mdata, modalities)

Validate and process the input data.

Checks that the data are AnnData or MuData objects and that the modalities are correctly specified.

ShareTopic_LDA_Multi.check_modalities_names()

Standardize and validate modality keys in data_dict.

Maps various synonyms to ‘rna’, ‘protein’, or ‘chromatin’, and rebuilds data_dict with standardized keys.

ShareTopic_LDA_Multi.clear_metric_cache()

Clear the cached metrics.

Call this method after retraining the model to ensure metrics are recomputed with the updated parameters.

ShareTopic_LDA_Multi.cross_modality_score(mod_a, mod_b, *, normalise=True, return_df=True)

Compute a SHARE-Topic-style cross-modal interaction matrix P_{a,b}.

Parameters:
  • mod_a (str) – Name of the first modality; its features index the rows of P.

  • mod_b (str) – Name of the second modality; its features index the columns of P.

  • normalise (bool (default: True)) – Whether to normalise the interaction scores.

  • return_df (bool (default: True)) – If True, return a pandas DataFrame; otherwise return a NumPy ndarray.

Return type:

ndarray | DataFrame

Returns:

P : shape (n_feat_a, n_feat_b) – Interaction score between every feature of mod_a and every feature of mod_b.
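The docstring does not spell out the formula. As one hedged reading of a SHARE-Topic-style score, the sketch below marginalises over topics: P[i, j] = Σ_k p(k) φ_a[k, i] φ_b[k, j], taking p(k) as the mean of the cell-topic proportions (C × K, as returned by get_cell_topic_dist) and φ_a, φ_b as feature-topic matrices (K × G_m, as returned by get_feature_topic_dist). The choice of p(k) and the normalisation to a total of 1 are assumptions, and interaction_matrix is a hypothetical helper.

```python
import numpy as np


def interaction_matrix(theta, phi_a, phi_b, normalise=True):
    """Hypothetical sketch of a topic-marginalised cross-modal score.

    theta : (C, K) cell-topic proportions
    phi_a : (K, Ga) feature-topic matrix for modality a
    phi_b : (K, Gb) feature-topic matrix for modality b
    Returns P with shape (Ga, Gb).
    """
    theta = np.asarray(theta, dtype=float)
    phi_a = np.asarray(phi_a, dtype=float)
    phi_b = np.asarray(phi_b, dtype=float)
    # Assumed topic weight p(k): average proportion across cells.
    p_topic = theta.mean(axis=0)
    # P[i, j] = sum_k p(k) * phi_a[k, i] * phi_b[k, j]
    P = phi_a.T @ (p_topic[:, None] * phi_b)
    if normalise:
        # Assumed normalisation: rescale so all scores sum to 1.
        P = P / P.sum()
    return P
```

When every row of φ_a and φ_b sums to 1, the unnormalised P already sums to 1, so the assumed normalisation only matters for unnormalised inputs.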

ShareTopic_LDA_Multi.fit(*, batch_size=512, n_samples=500, thin=10, burnin=1000, initial_burnin=500, progress=True, ll_every=1, auto_burnin=True, min_initial_burnin=50, burnin_window=20, burnin_rtol=0.0001, burnin_patience=3)

Run the Gibbs sampler.

Parameters:
  • batch_size (int (default: 512)) – Mini-batch size (cells per update).

  • n_samples (int (default: 500)) – Number of retained posterior samples.

  • thin (int (default: 10)) – Keep every thin-th iteration during sampling.

  • burnin (int (default: 1000)) – Number of additional iterations discarded after the initial burn-in.

  • initial_burnin (int (default: 500)) – Maximum initial equilibration steps (may end early with auto_burnin).

  • progress (bool (default: True)) – Show tqdm progress bars.

  • ll_every (int (default: 1)) – Evaluate joint log-likelihood every ll_every iterations.

  • auto_burnin (bool (default: True)) – Enable automatic thermalization detection during initial burn-in.

  • min_initial_burnin (int (default: 50)) – Minimum iterations before auto burn-in can trigger.

  • burnin_window (int (default: 20)) – Sliding window size for running-mean comparison.

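  • burnin_rtol (float (default: 0.0001)) – Relative tolerance for log-likelihood (LL) stabilisation: the relative change in the running-mean LL must fall below this value.

  • burnin_patience (int (default: 3)) – Number of consecutive checks below burnin_rtol required to end the initial burn-in early.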

Return type:

None
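The exact stopping rule used by auto_burnin is not documented here. The sketch below shows one plausible reading of the parameters: after at least min_initial_burnin iterations, compare the mean log-likelihood of the latest burnin_window values with that of the preceding window, and stop once the relative change stays below burnin_rtol for burnin_patience consecutive checks. detect_burnin is a hypothetical helper operating on a recorded LL trace.

```python
import numpy as np


def detect_burnin(ll_trace, window=20, rtol=1e-4, patience=3, min_iter=50):
    """Hypothetical sketch of automatic burn-in detection on an LL trace.

    window ~ burnin_window, rtol ~ burnin_rtol,
    patience ~ burnin_patience, min_iter ~ min_initial_burnin.
    Returns the iteration count at which burn-in is declared over,
    or None if the trace never stabilises.
    """
    ll = np.asarray(ll_trace, dtype=float)
    passes = 0
    # Two full windows are needed before the first comparison.
    for t in range(max(min_iter, 2 * window), len(ll) + 1):
        recent = ll[t - window:t].mean()
        previous = ll[t - 2 * window:t - window].mean()
        rel_change = abs(recent - previous) / max(abs(previous), 1e-12)
        if rel_change < rtol:
            passes += 1
            if passes >= patience:
                return t
        else:
            passes = 0  # the streak must be consecutive
    return None
```

On a trace that rises linearly for 60 iterations and then stays flat, both windows first lie entirely on the plateau at iteration 100, and with patience 3 the sketch returns 102.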

ShareTopic_LDA_Multi.generate(n_cells, theta=None)

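Sample synthetic data from the fitted model.

Parameters:
  • n_cells (int) – Number of synthetic cells to generate.

  • theta (default: None) – Optional cell-topic proportions to generate from.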

Return type:

dict[str, AnnData]

Returns:

dict[str, AnnData] – One AnnData per modality with generated counts.
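The generative process behind sampling from an LDA-type model can be sketched for a single Dirichlet-multinomial modality (such as chromatin with prior beta). This is an illustration only: it does not cover the NB likelihoods used for RNA or protein, it assumes a fixed number of counts per cell, and generate_counts is a hypothetical helper rather than the method's implementation.

```python
import numpy as np


def generate_counts(phi, alpha, n_cells, depth, rng=None, theta=None):
    """Hypothetical sketch of LDA-style generation for one multinomial modality.

    phi   : (K, G) topic-feature probabilities, rows summing to 1
    alpha : symmetric Dirichlet concentration for cell-topic proportions
    depth : total counts per cell (fixed here for simplicity; an assumption)
    theta : optional (n_cells, K) cell-topic proportions; drawn from the prior if None
    """
    rng = np.random.default_rng(rng)
    phi = np.asarray(phi, dtype=float)
    n_topics = phi.shape[0]
    if theta is None:
        theta = rng.dirichlet(np.full(n_topics, alpha), size=n_cells)
    # Each cell's feature distribution is a theta-weighted mixture of topics;
    # drawing counts from it matches sampling a topic per count, then a feature.
    probs = theta @ phi
    probs = probs / probs.sum(axis=1, keepdims=True)
    counts = np.stack([rng.multinomial(depth, p) for p in probs])
    return counts, theta
```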

ShareTopic_LDA_Multi.get_cell_topic_dist(normalised=True)

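Posterior mean cell-topic proportions (C × K).

Parameters:

normalised (bool (default: True)) – If True, each cell's topic proportions are normalised to sum to 1.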

Return type:

ndarray

ShareTopic_LDA_Multi.get_entropy(normalised=True)

Compute mean entropy of cell-topic distributions.

Higher entropy means each cell's topic proportions are spread more evenly across topics. This measures the uncertainty in topic assignments per cell.

Parameters:

normalised (bool (default: True)) – Whether to normalize cell-topic distributions before computing entropy. If True, ensures distributions sum to 1 (default: True).

Return type:

float

Returns:

float – Mean entropy across all cells.
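The metric is the Shannon entropy of each row of the C × K cell-topic matrix, averaged over cells. The sketch below assumes natural logarithms (the docstring does not state the base) and treats 0 · log 0 as 0; mean_topic_entropy is a hypothetical helper.

```python
import numpy as np


def mean_topic_entropy(theta, normalised=True):
    """Hypothetical sketch: mean Shannon entropy (nats) of cell-topic rows."""
    theta = np.asarray(theta, dtype=float)
    if normalised:
        theta = theta / theta.sum(axis=1, keepdims=True)
    # Use 0 * log(0) = 0; silence the log(0) and 0 * -inf warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        terms = np.where(theta > 0, theta * np.log(theta), 0.0)
    return float(-terms.sum(axis=1).mean())
```

A cell assigned entirely to one topic contributes 0, and a cell spread uniformly over K topics contributes log K, the maximum.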

ShareTopic_LDA_Multi.get_feature_topic_dist(modality)

Posterior mean feature-topic matrix (K × G_m), where K is the number of topics and G_m the number of features in the given modality.

Return type:

ndarray

ShareTopic_LDA_Multi.get_likelihood_per_modality(**kwargs)

Compute log-likelihood for each modality separately.

Higher is better.

Return type:

dict[str, float]

Returns:

dict[str, float] – Dictionary mapping modality names to log-likelihood values.

ShareTopic_LDA_Multi.get_modality_weights(**kwargs)

Not applicable to the Gibbs sampler; returns equal weights.

ShareTopic_LDA_Multi.get_perplexity(**kwargs)

Compute the overall perplexity, exp(-LL / total_counts). Lower is better.

Return type:

float

ShareTopic_LDA_Multi.get_perplexity_per_modality(**kwargs)

Compute perplexity for each modality separately.

Lower is better. Perplexity = exp(-log_likelihood / N_tokens), where N_tokens is the total number of observed counts in the modality.

Return type:

dict[str, float]

Returns:

dict[str, float] – Dictionary mapping modality names to perplexity values.
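Both perplexity methods apply the same formula, exp(-LL / N). The sketch below uses made-up log-likelihoods and count totals (illustrative numbers, not results from a fitted model) and assumes the overall LL and count total are sums over modalities; perplexity is a hypothetical helper.

```python
import math


def perplexity(log_likelihood, n_counts):
    """exp(-LL / N); lower means higher average probability per count."""
    return math.exp(-log_likelihood / n_counts)


# Illustrative numbers only, not results from a fitted model.
ll = {"rna": -2.0e6, "chromatin": -3.0e5}
counts = {"rna": 4.0e5, "chromatin": 1.0e5}

per_modality = {m: perplexity(ll[m], counts[m]) for m in ll}
# Assumes the overall LL and count total are sums over modalities.
overall = perplexity(sum(ll.values()), sum(counts.values()))
```

Here per_modality gives exp(5) for RNA and exp(3) for chromatin, and overall gives exp(4.6). Under the summation assumption the overall value is a count-weighted geometric mean of the per-modality perplexities (weights 0.8 and 0.2), so modalities with more counts dominate it.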

ShareTopic_LDA_Multi.get_top_features_per_topic(modality, n_features=10, return_scores=False)

Get top N features for each topic in a specific modality.

Parameters:
  • modality (str) – Modality name (e.g., ‘rna’, ‘protein’, ‘chromatin’).

  • n_features (int (default: 10)) – Number of top features to return per topic (default: 10)

  • return_scores (bool (default: False)) – If True, return (feature_name, score) tuples. If False, return feature names only (default: False).

Return type:

dict[str, list[str]] | dict[str, list[tuple[str, float]]]

Returns:

dict[str, list[str]] or dict[str, list[tuple[str, float]]] – Dictionary mapping topic names (e.g., ‘topic_0’) to lists of top feature names or (feature_name, score) tuples.
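A minimal sketch of how the returned dictionary can be built from a K × G_m feature-topic matrix, assuming features are ranked by their raw topic weight (the method may use a different score). top_features_per_topic is a hypothetical helper.

```python
import numpy as np


def top_features_per_topic(phi, feature_names, n_features=10, return_scores=False):
    """Hypothetical sketch: rank features within each topic by raw weight."""
    phi = np.asarray(phi, dtype=float)
    result = {}
    for k, row in enumerate(phi):
        # Indices of the n_features largest weights, highest first.
        idx = np.argsort(row)[::-1][:n_features]
        if return_scores:
            result[f"topic_{k}"] = [(feature_names[i], float(row[i])) for i in idx]
        else:
            result[f"topic_{k}"] = [feature_names[i] for i in idx]
    return result
```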

ShareTopic_LDA_Multi.get_topic_diversity(modality=None)

Compute topic diversity as average pairwise cosine distance.

Higher values indicate more distinct topics. This metric measures how different the topic-feature distributions are from each other.

Parameters:

modality (str | None (default: None)) – If provided, compute diversity for this specific modality’s feature-topic distribution. If None, compute diversity averaged across all modalities (default: None).

Return type:

float

Returns:

float – Average pairwise cosine distance between topic distributions, in [0, 1]. Higher values indicate more diverse, distinct topics.
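The metric averages 1 − cos(φ_i, φ_j) over all pairs of topic rows in a K × G_m feature-topic matrix. Because topic distributions are non-negative, each cosine similarity lies in [0, 1], which is why the result stays in [0, 1]. topic_diversity is a hypothetical helper sketching this computation for one modality.

```python
from itertools import combinations

import numpy as np


def topic_diversity(phi):
    """Average pairwise cosine distance between the K rows of a (K, G) matrix."""
    phi = np.asarray(phi, dtype=float)
    if phi.shape[0] < 2:
        raise ValueError("need at least two topics")
    unit = phi / np.linalg.norm(phi, axis=1, keepdims=True)
    dists = [1.0 - float(unit[i] @ unit[j])
             for i, j in combinations(range(phi.shape[0]), 2)]
    return float(np.mean(dists))
```

For modality=None, the docstring says the value is averaged across modalities; with this sketch that would be the mean of topic_diversity over each modality's matrix, assuming an unweighted average.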

classmethod ShareTopic_LDA_Multi.load(path, mdata, *, modalities=None, device=None)

Load a saved model.

The original data (mdata) must be passed again because we do not serialise the raw count matrices.

Return type:

ShareTopic_LDA_Multi

ShareTopic_LDA_Multi.predict(data)

Predict using the fitted model on the provided data.

Parameters:

data – The input data for prediction.

ShareTopic_LDA_Multi.save(path)

Save fitted model to a .pt file.

Return type:

None