Tools (stories.tools)

class stories.tools.DataLoader(adata: AnnData, time_key: str, omics_key: str, space_key: str, batch_size: int, train_val_split: float, weight_key: str | None = None)

Bases: object

DataLoader feeds data from an AnnData object to the model as JAX arrays. It samples without replacement for a given batch size.

Parameters:
  • adata (AnnData) – The input AnnData object.

  • time_key (str) – The obs field with float time observations.

  • omics_key (str) – The obsm field with the omics coordinates.

  • space_key (str) – The obsm field with the spatial coordinates.

  • batch_size (int) – The batch size.

  • train_val_split (float) – The proportion of observations assigned to the training set.

  • weight_key (str, optional) – The obs field with the marginal weights.

adata: AnnData
batch_size: int
make_train_val_split(key: Array) → None

Make a train/validation split. Must be called before training.

Parameters:

key (PRNGKey) – The random number generator key for permutations.
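
As a rough illustration of what a permutation-based split involves, the standalone sketch below shuffles observation indices with a seeded generator and cuts them at the train proportion. It uses Python's `random` module instead of a JAX PRNGKey, and the helper name `split_indices` is hypothetical. It is not the DataLoader implementation.

```python
import random


def split_indices(n_obs, train_val_split, seed):
    """Hypothetical helper: permute indices, then cut at the train proportion."""
    indices = list(range(n_obs))
    random.Random(seed).shuffle(indices)  # seeded, hence reproducible, permutation
    n_train = int(train_val_split * n_obs)  # number of training observations
    return indices[:n_train], indices[n_train:]


# 8 observations with 75% assigned to train: 6 train indices, 2 val indices.
train_idx, val_idx = split_indices(n_obs=8, train_val_split=0.75, seed=0)
```

The two index lists are disjoint and together cover every observation, which is the property a train/validation split needs.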

next(key: Array, train_or_val: str) → Dict[str, Array]

Get the next batch from either train or val indices.

Parameters:
  • key (jax.Array) – The random number generator key for sampling.

  • train_or_val (str) – Either “train” or “val”.

Returns:

A dictionary of JAX arrays.

Return type:

Dict[str, jax.Array]
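
The sketch below illustrates the idea of drawing a batch without replacement and returning it as a dictionary. It is a plain-Python analogy, not the library's code: the helper `sample_batch`, the toy data, and the dictionary keys `"time"` and `"omics"` are all assumptions made for illustration.

```python
import random


def sample_batch(data, indices, batch_size, seed):
    """Hypothetical helper: draw batch_size distinct indices and gather fields."""
    rng = random.Random(seed)
    batch_idx = rng.sample(indices, batch_size)  # sampling without replacement
    # Gather each field at the sampled indices (key names are illustrative).
    return {name: [values[i] for i in batch_idx] for name, values in data.items()}


# Toy data: 100 observations with a time stamp and a 1-D "omics" coordinate.
toy_data = {
    "time": [float(i % 3) for i in range(100)],
    "omics": [i / 100 for i in range(100)],
}
train_idx = list(range(80))  # pretend the first 80 observations are the train set
batch = sample_batch(toy_data, train_idx, batch_size=16, seed=0)
```

Because the indices are drawn without replacement, no observation appears twice within a single batch.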

omics_key: str
space_key: str
time_key: str
train_or_val(iteration: int) → bool

Sample whether to train or validate.

Parameters:

iteration (int) – The current iteration.

Returns:

True for train, False for val.

Return type:

bool

train_val_split: float
weight_key: str | None = None
stories.tools.compute_potential(adata: AnnData, model, omics_key: str, key_added: str = 'potential') → None

Compute the potential for all cells in an AnnData object.

Parameters:
  • adata (AnnData) – Input data

  • model (SpaceTime) – Trained model

  • omics_key (str) – The omics key

  • key_added (str) – The obs key under which to store the potential. Defaults to “potential”.

stories.tools.compute_velocity(adata: AnnData, model, omics_key: str, key_added: str = 'X_velo') → None

Compute -grad J for all cells in an AnnData object, where J is the potential.

Parameters:
  • adata (AnnData) – Input data

  • model (SpaceTime) – Trained model

  • omics_key (str) – The omics key

  • key_added (str) – The obsm key under which to store the velocity. Defaults to “X_velo”.
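
To make the relation between potential and velocity concrete, here is a minimal sketch that evaluates -grad J for a toy potential using central finite differences. The quadratic potential and the helper `velocity` are assumptions for illustration only. A trained model would supply its own J, and in a JAX setting the gradient would more naturally come from automatic differentiation (for example `jax.grad`).

```python
def potential(x):
    """Toy potential J(x) = 0.5 * ||x||^2 (an assumption, not a trained model)."""
    return 0.5 * sum(v * v for v in x)


def velocity(J, x, eps=1e-6):
    """Approximate -grad J at x with central finite differences."""
    grad = []
    for i in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((J(x_plus) - J(x_minus)) / (2 * eps))
    return [-g for g in grad]  # the velocity points down the potential


# For this toy J, grad J(x) = x, so the velocity at (1, -2) is close to (-1, 2).
v = velocity(potential, [1.0, -2.0])
```

The sign convention matters: cells move toward lower potential, so the velocity is the negative of the gradient.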

stories.tools.default_checkpoint_manager(absolute_path: str) → CheckpointManager

Return a checkpoint manager.

Parameters:

absolute_path (str) – Checkpointing path

stories.tools.plot_losses(model)
stories.tools.plot_single_gene_trend(adata, gene, potential_key='potential', annotation_key='annotation', regression_key='regression', show_regression=False, **kwargs)
stories.tools.plot_velocity(adata: AnnData, omics_key: str, basis: str, velocity_key: str = 'X_velo', **kwargs) → None

Plot the velocity, as computed by compute_velocity.

Parameters:
  • adata (AnnData) – Input data

  • omics_key (str) – The obsm key for omics

  • velocity_key (str) – The obsm key for the velocity

stories.tools.regress_genes(adata, potential_key='potential', regression_model=None, key_added='regression') → None
stories.tools.select_driver_genes(adata, n_stages: int, n_genes: int, regression_key='regression', remove_ones=True)
stories.tools.tf_enrich(adata, trrust_path, regression_key='regression', gene_key=None)