Tools (stories.tools)
- class stories.tools.DataLoader(adata: AnnData, time_key: str, omics_key: str, space_key: str, batch_size: int, train_val_split: float, weight_key: str | None = None)
Bases: object
DataLoader feeds data from an AnnData object to the model as JAX arrays. Each batch of batch_size observations is sampled without replacement.
- Parameters:
adata (AnnData) – The input AnnData object.
time_key (str) – The obs field with the float time observations.
omics_key (str) – The obsm field with the omics coordinates.
space_key (str) – The obsm field with the spatial coordinates.
batch_size (int) – The batch size.
train_val_split (float) – The proportion of observations assigned to the training set.
weight_key (str, optional) – The obs field with the marginal weights.
- adata: AnnData
- batch_size: int
- make_train_val_split(key: Array) → None
Make a train/validation split. Must be called before training.
- Parameters:
key (PRNGKey) – The random number generator key for permutations.
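A minimal sketch of what a train/validation split over cell indices can look like, assuming a random permutation of all indices that is cut at the train_val_split proportion. The make_train_val_split helper below is a hypothetical stand-in, not the library's method, and it swaps the JAX PRNGKey for Python's random.Random so it runs with the standard library alone.

```python
import random
from typing import List, Tuple


def make_train_val_split(
    n_obs: int, train_val_split: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Shuffle observation indices and cut them into train and val sets.

    Hypothetical stand-in: stories takes a JAX PRNGKey, this sketch
    seeds Python's random.Random with an int instead.
    """
    indices = list(range(n_obs))
    random.Random(seed).shuffle(indices)  # random permutation of all cells
    n_train = int(train_val_split * n_obs)  # proportion assigned to train
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = make_train_val_split(n_obs=10, train_val_split=0.8, seed=0)
print(len(train_idx), len(val_idx))  # 8 2
```

Because the same seed gives the same permutation, the split is reproducible, which mirrors why the real method takes an explicit random key.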
- next(key: Array, train_or_val: str) → Dict[str, Array]
Get the next batch from either train or val indices.
- Parameters:
key (jax.Array) – The random number generator key for sampling.
train_or_val (str) – Either “train” or “val”.
- Returns:
A dictionary of JAX arrays.
- Return type:
Dict[str, jax.Array]
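The sketch below illustrates how a batch can be drawn without replacement from either the train or the val index pool. The next_batch helper, the plain-dict fields, and the field names "time" and "omics" are hypothetical stand-ins for the AnnData columns and the JAX arrays the real method returns.

```python
import random
from typing import Dict, List, Sequence


def next_batch(
    fields: Dict[str, Sequence],
    split_indices: Dict[str, List[int]],
    train_or_val: str,
    batch_size: int,
    seed: int,
) -> Dict[str, list]:
    """Draw one batch without replacement from the train or val indices.

    Hypothetical stand-in: the field names and dict layout are
    illustrative, not the keys stories actually returns.
    """
    if train_or_val not in ("train", "val"):
        raise ValueError('train_or_val must be "train" or "val"')
    pool = split_indices[train_or_val]
    # random.sample never picks the same index twice within one batch;
    # the batch is capped at the pool size so a small val set still works
    batch_idx = random.Random(seed).sample(pool, min(batch_size, len(pool)))
    return {name: [values[i] for i in batch_idx] for name, values in fields.items()}


fields = {"time": [0.0, 0.0, 1.0, 1.0, 2.0], "omics": [[0.1], [0.2], [0.3], [0.4], [0.5]]}
splits = {"train": [0, 1, 2, 3], "val": [4]}
batch = next_batch(fields, splits, "train", batch_size=3, seed=0)
print(batch["time"])
```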
- omics_key: str
- space_key: str
- time_key: str
- train_or_val(iteration: int) → bool
Sample whether to train or validate.
- Parameters:
iteration (int) – The current iteration.
- Returns:
True for train, False for val.
- Return type:
bool
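How train_or_val makes its choice is not specified here. The sketch below assumes one plausible, hypothetical rule: train with probability train_val_split, seeding the draw with the iteration so a rerun repeats the same schedule. Unlike the real method, which reads the split from the loader, this standalone function takes train_val_split as an explicit argument.

```python
import random


def train_or_val(iteration: int, train_val_split: float) -> bool:
    """Return True to train and False to validate at this iteration.

    Assumed rule (hypothetical): train with probability train_val_split,
    using the iteration as the seed so the choice is reproducible.
    """
    return random.Random(iteration).random() < train_val_split


n_train = sum(train_or_val(i, 0.8) for i in range(10_000))
print(n_train / 10_000)  # close to 0.8
```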
- train_val_split: float
- weight_key: str | None = None
- stories.tools.compute_potential(adata: AnnData, model, omics_key: str, key_added: str = 'potential') → None
Compute the potential for all cells in an AnnData object.
- Parameters:
adata (AnnData) – Input data.
model (SpaceTime) – The trained model.
omics_key (str) – The obsm field with the omics coordinates.
key_added (str) – The obs field in which to store the potential. Defaults to “potential”.
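A minimal sketch of the idea behind compute_potential: evaluate a scalar potential J on each cell's omics coordinates and store one value per cell, modifying the container in place and returning None. A plain dict stands in for adata.obs, a list of rows stands in for adata.obsm[omics_key], and the quadratic toy_potential is a hypothetical substitute for the trained model.

```python
from typing import Callable, Dict, Sequence


def compute_potential(
    obs: Dict[str, list],
    omics: Sequence[Sequence[float]],
    potential_fn: Callable[[Sequence[float]], float],
    key_added: str = "potential",
) -> None:
    """Evaluate a scalar potential on every cell and store it in place.

    Hypothetical stand-in: `obs` plays adata.obs, `omics` plays
    adata.obsm[omics_key], and `potential_fn` plays the trained model.
    """
    obs[key_added] = [potential_fn(x) for x in omics]


def toy_potential(x: Sequence[float]) -> float:
    # Assumed toy potential J(x) = 0.5 * ||x||^2, for illustration only
    return 0.5 * sum(v * v for v in x)


obs = {}
compute_potential(obs, [[0.0, 0.0], [1.0, 0.0], [1.0, 2.0]], toy_potential)
print(obs["potential"])  # [0.0, 0.5, 2.5]
```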
- stories.tools.compute_velocity(adata: AnnData, model, omics_key: str, key_added: str = 'X_velo') → None
Compute the velocity −∇J (the negative gradient of the potential J) for all cells in an AnnData object.
- Parameters:
adata (AnnData) – Input data.
model (SpaceTime) – The trained model.
omics_key (str) – The obsm field with the omics coordinates.
key_added (str) – The obsm field in which to store the velocity. Defaults to “X_velo”.
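The velocity stored by compute_velocity is −∇J, so each cell's vector points toward lower potential. The sketch below approximates that gradient with central finite differences on a hypothetical quadratic potential; the library obtains the gradient from the trained model instead, and finite differences are swapped in here only so the example needs no dependencies.

```python
from typing import Callable, List, Sequence


def velocity(
    potential_fn: Callable[[Sequence[float]], float],
    x: Sequence[float],
    eps: float = 1e-5,
) -> List[float]:
    """Approximate -grad J(x) with central finite differences."""
    grad = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += eps  # nudge coordinate i up
        minus[i] -= eps  # and down
        grad.append((potential_fn(plus) - potential_fn(minus)) / (2 * eps))
    return [-g for g in grad]  # negate: velocity points downhill in J


def toy_potential(x: Sequence[float]) -> float:
    # Hypothetical J(x) = 0.5 * ||x||^2, whose exact -grad J(x) is -x
    return 0.5 * sum(v * v for v in x)


v = velocity(toy_potential, [1.0, 2.0])
print(v)  # approximately [-1.0, -2.0]
```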
- stories.tools.default_checkpoint_manager(absolute_path: str) → CheckpointManager
Return a default checkpoint manager that saves checkpoints under absolute_path.
- Parameters:
absolute_path (str) – The absolute path of the checkpoint directory.
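The parameter name indicates that default_checkpoint_manager expects an absolute path. A small hypothetical helper, checkpoint_dir (not part of stories), can resolve a relative directory first; the actual call is left commented out because it requires stories to be installed.

```python
import os


def checkpoint_dir(relative_dir: str = "checkpoints") -> str:
    """Resolve a directory name against the current working directory.

    Hypothetical helper, not part of stories.
    """
    return os.path.abspath(relative_dir)


path = checkpoint_dir("checkpoints")
print(path)
# manager = stories.tools.default_checkpoint_manager(path)  # requires stories
```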
- stories.tools.plot_gene_trends(adata, gene_names, potential_key='potential', regression_key='regression', title='')
- stories.tools.plot_losses(model)
- stories.tools.plot_single_gene_trend(adata, gene, potential_key='potential', annotation_key='annotation', regression_key='regression', show_regression=False, **kwargs)
- stories.tools.plot_velocity(adata: AnnData, omics_key: str, basis: str, velocity_key: str = 'X_velo', **kwargs) → None
Plot the velocity computed by compute_velocity.
- Parameters:
adata (AnnData) – Input data.
omics_key (str) – The obsm field with the omics coordinates.
velocity_key (str) – The obsm field with the velocity. Defaults to “X_velo”.
- stories.tools.regress_genes(adata, potential_key='potential', regression_model=None, key_added='regression') → None
- stories.tools.select_driver_genes(adata, n_stages: int, n_genes: int, regression_key='regression', remove_ones=True)
- stories.tools.tf_enrich(adata, trrust_path, regression_key='regression', gene_key=None)