Model (stories.spacetime)

class stories.spacetime.SpaceTime(potential: flax.linen.module.Module = MLPPotential(features=(128, 128), activation=gelu), proximal_step: stories.steps.proximal_step.ProximalStep = <stories.steps.explicit.ExplicitStep object>, n_steps: int = 1, teacher_forcing: bool = True, quadratic: bool = True, debias: bool = True, epsilon: float = 0.01, log_callback: typing.Callable | None = None, quadratic_weight: float = 0.005)

Bases: object

Wasserstein gradient flow model for spatio-temporal omics data.
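A Wasserstein gradient flow moves each cell downhill on a potential Φ. The sketch below illustrates an explicit (forward Euler) update, x ← x − (τ / n) ∇Φ(x), applied n times. It is a hypothetical NumPy illustration, not the `stories` code: the toy potential Φ(x) = ½‖x‖² stands in for the learned `MLPPotential`, and the names `grad_quadratic_potential` and `explicit_flow`, as well as splitting `tau` evenly over `n_steps`, are assumptions.

```python
import numpy as np


def grad_quadratic_potential(x: np.ndarray) -> np.ndarray:
    """Gradient of the toy potential Phi(x) = 0.5 * ||x||^2 (a stand-in for a learned MLP)."""
    return x


def explicit_flow(x: np.ndarray, tau: float, n_steps: int = 1,
                  grad_potential=grad_quadratic_potential) -> np.ndarray:
    """Move every cell along -grad(Phi) using n_steps forward Euler sub-steps of size tau / n_steps.

    Hypothetical sketch: assumes tau is split evenly over n_steps.
    """
    dt = tau / n_steps
    for _ in range(n_steps):
        x = x - dt * grad_potential(x)
    return x


cells = np.array([[1.0, 2.0], [-3.0, 0.5]])  # two cells in a 2-D omics space
one_step = explicit_flow(cells, tau=0.1)                # every coordinate is scaled by 0.9
ten_steps = explicit_flow(cells, tau=0.1, n_steps=10)   # every coordinate is scaled by 0.99**10
```

With this toy potential each sub-step simply contracts the cloud towards the origin; a learned potential would instead push cells along data-driven directions.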

Parameters:
  • potential (nn.Module, optional) – The potential function. Defaults to an MLP (MLPPotential) with features=(128, 128) and GELU activation.

  • proximal_step (ProximalStep, optional) – The proximal step. Defaults to an explicit (forward Euler) step, ExplicitStep.

  • n_steps (int, optional) – The number of proximal steps. Defaults to 1.

  • teacher_forcing (bool, optional) – Whether to use teacher forcing. Defaults to True.

  • quadratic (bool, optional) – Whether to use a fused Gromov-Wasserstein (FGW) loss. Defaults to True.

  • debias (bool, optional) – Whether to debias the loss. Defaults to True.

  • epsilon (float, optional) – The (relative) entropic regularization. Defaults to 0.01.

  • log_callback (Callable, optional) – The callback for logging. Defaults to None.

  • quadratic_weight (float, optional) – Weight of the quadratic term, in [0, 1]. Defaults to 0.005.
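To make `epsilon` and `debias` more concrete, here is a small NumPy sketch of a debiased entropic OT loss (a Sinkhorn divergence) between two uniformly weighted point clouds, computed with log-domain Sinkhorn iterations. It assumes a squared Euclidean cost, a "relative" ε equal to `epsilon` times the mean of the cross cost matrix (reused for all three terms), and a fixed iteration count. All function names are hypothetical, and the actual loss in `stories`, including the fused Gromov-Wasserstein variant enabled by `quadratic`, may differ.

```python
import numpy as np


def _sq_dists(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of x and y."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def _logsumexp(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(z))) along one axis."""
    m = z.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def entropic_ot(x: np.ndarray, y: np.ndarray, eps: float, n_iter: int = 500) -> float:
    """Entropic OT cost between uniformly weighted point clouds x and y."""
    cost = _sq_dists(x, y)
    log_a = np.full(len(x), -np.log(len(x)))  # uniform weights, stored as logs
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        # Log-domain Sinkhorn: alternately enforce the row and column marginals.
        f = -eps * _logsumexp((g[None, :] - cost) / eps + log_b[None, :], axis=1)
        g = -eps * _logsumexp((f[:, None] - cost) / eps + log_a[:, None], axis=0)
    # Right after a g-update the mass term of the dual is zero, so the dual value is <a, f> + <b, g>.
    return float(np.exp(log_a) @ f + np.exp(log_b) @ g)


def sinkhorn_loss(x: np.ndarray, y: np.ndarray, epsilon: float = 0.01, debias: bool = True) -> float:
    """Entropic OT loss with a relative epsilon, optionally debiased into a Sinkhorn divergence."""
    eps = epsilon * _sq_dists(x, y).mean()  # relative epsilon: a fraction of the mean cross cost
    loss = entropic_ot(x, y, eps)
    if debias:
        # Subtract the self-transport terms so that sinkhorn_loss(x, x) == 0.
        loss -= 0.5 * (entropic_ot(x, x, eps) + entropic_ot(y, y, eps))
    return loss


rng = np.random.default_rng(0)
source = rng.normal(size=(50, 2))
same = sinkhorn_loss(source, source)            # debiasing makes this exactly zero
shifted = sinkhorn_loss(source, source + 3.0)   # positive: the clouds are far apart
```

Without debiasing, the entropic term makes the loss between a cloud and itself nonzero; subtracting the two self-transport terms removes that bias.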

debias: bool = True
epsilon: float = 0.01
fit(adata: anndata.AnnData, time_key: str, omics_key: str, space_key: str, weight_key: str | None = None, optimizer: optax.GradientTransformation = optax.chain(...), max_iter: int = 10000, batch_size: int = 1000, train_val_split: float = 0.75, min_delta: float = 0.0, patience: int = 150, checkpoint_manager: orbax.checkpoint.CheckpointManager | str | None = None, key: jax.Array = jax.random.PRNGKey(0), restore: bool = True) → None

Fit the model.

Parameters:
  • adata (AnnData) – The AnnData object.

  • time_key (str) – The name of the obs field containing the time points.

  • omics_key (str) – The name of the obsm field containing the cells' omics coordinates.

  • space_key (str) – The name of the obsm field containing the spatial coordinates.

  • weight_key (str, optional) – The name of the obs field containing cell weights. Defaults to None.

  • optimizer (GradientTransformation, optional) – The optimizer.

  • max_iter (int, optional) – The max number of iterations. Defaults to 10_000.

  • batch_size (int, optional) – The batch size. Defaults to 1_000.

  • train_val_split (float, optional) – The fraction of the data used for training; the rest is held out for validation. Defaults to 0.75.

  • min_delta (float, optional) – The minimum decrease in the validation loss that counts as an improvement for early stopping. Defaults to 0.0.

  • patience (int, optional) – The number of checks without improvement tolerated before early stopping. Defaults to 150.

  • checkpoint_manager (CheckpointManager | str, optional) – A checkpoint manager, or a path to a checkpoint directory. Defaults to None.

  • key (jax.Array, optional) – The random key. Defaults to PRNGKey(0).

  • restore (bool, optional) – Whether to load previously checkpointed parameters. Defaults to True.
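The `min_delta` and `patience` arguments follow the usual early-stopping pattern. The sketch below is a generic, hypothetical version of that rule (the function name `early_stop_index` and the exact bookkeeping are assumptions, not the `stories` implementation): training stops once the validation loss, measured on the held-out part of `train_val_split`, has failed to improve on its best value by more than `min_delta` for `patience` consecutive checks.

```python
def early_stop_index(val_losses, min_delta: float = 0.0, patience: int = 150):
    """Return the check index at which training would stop, or None if it runs to the end."""
    best = float("inf")
    checks_without_improvement = 0
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement larger than min_delta: record it and reset the counter.
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return i
    return None


# The loss keeps decreasing, but only by 0.01 per check.
history = [1.0, 0.99, 0.98, 0.97]
no_stop = early_stop_index(history, min_delta=0.0, patience=2)     # None: every check improves
early = early_stop_index(history, min_delta=0.05, patience=2)      # 2: improvements are below min_delta
```

Raising `min_delta` makes the rule stricter about what counts as progress, while raising `patience` gives noisy validation curves more time to recover.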

log_callback: Callable | None = None
n_steps: int = 1
potential: Module = MLPPotential(features=(128, 128), activation=gelu)
proximal_step: ProximalStep = <stories.steps.explicit.ExplicitStep object>
quadratic: bool = True
quadratic_weight: float = 0.005
teacher_forcing: bool = True
transform(adata: AnnData, omics_key: str, tau: float, batch_size: int = 1000, key: Array = PRNGKey(0)) → ndarray

Transform an AnnData object.

Parameters:
  • adata (AnnData) – The AnnData object to transform.

  • omics_key (str) – The obsm field containing the data to transform.

  • tau (float) – The time step.

  • batch_size (int, optional) – The batch size. Defaults to 1_000.

  • key (jax.Array, optional) – The random key. Defaults to PRNGKey(0).

Returns:

The predictions.

Return type:

np.ndarray
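The `batch_size` argument bounds how many cells are processed at once. A minimal sketch of that chunking pattern follows, assuming a deterministic per-cell map; `transform_in_batches` and `toy_step` are hypothetical stand-ins for the learned flow, not `stories` functions. For such a map, batching changes peak memory use but not the result.

```python
import numpy as np


def transform_in_batches(x: np.ndarray, step_fn, batch_size: int = 1000) -> np.ndarray:
    """Apply a per-cell map to x in chunks of batch_size rows and stack the results."""
    chunks = [step_fn(x[start:start + batch_size]) for start in range(0, len(x), batch_size)]
    return np.concatenate(chunks, axis=0)


def toy_step(batch: np.ndarray, tau: float = 0.5) -> np.ndarray:
    """Hypothetical stand-in for the learned flow: one explicit step on Phi(x) = 0.5 * ||x||^2."""
    return batch - tau * batch


cells = np.random.default_rng(0).normal(size=(2_500, 16))       # 2,500 cells, 16 omics features
preds = transform_in_batches(cells, toy_step, batch_size=1_000)  # chunks of 1000, 1000 and 500 rows
```

If the real transform is stochastic (it accepts a random `key`), results may additionally depend on how that key is split across batches.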