Model (stories.spacetime)
- class stories.spacetime.SpaceTime(potential: flax.linen.Module = MLPPotential(features=(128, 128), activation=gelu), proximal_step: stories.steps.proximal_step.ProximalStep = stories.steps.explicit.ExplicitStep(), n_steps: int = 1, teacher_forcing: bool = True, quadratic: bool = True, debias: bool = True, epsilon: float = 0.01, log_callback: Callable | None = None, quadratic_weight: float = 0.005)
Bases: object
Wasserstein gradient flow model for spatio-temporal omics data.
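As background, here is a minimal sketch of the time discretization. It assumes, without confirmation from this reference, that the flow minimizes a potential energy F(μ) = ∫ Φ dμ, where Φ is given by `potential`. A proximal (JKO) step with time step τ solves the first problem below. The default explicit step replaces it with a forward Euler update of each cell x along −∇Φ.

```latex
\mu_{k+1} \in \arg\min_{\mu} \; F(\mu) + \frac{1}{2\tau} W_2^2(\mu, \mu_k),
\qquad F(\mu) = \int \Phi \, d\mu,
\\[4pt]
x_{k+1} = x_k - \tau \, \nabla \Phi(x_k).
```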
- Parameters:
potential (nn.Module, optional) – The potential function. Defaults to an MLP.
proximal_step (ProximalStep, optional) – The proximal step. Defaults to an explicit (forward Euler) step.
n_steps (int, optional) – The number of steps. Defaults to 1.
teacher_forcing (bool, optional) – Whether to use teacher forcing. Defaults to True.
quadratic (bool, optional) – Whether to use a fused Gromov-Wasserstein (FGW) loss. Defaults to True.
debias (bool, optional) – Whether to debias the loss. Defaults to True.
epsilon (float, optional) – The (relative) entropic regularization strength. Defaults to 0.01.
log_callback (Callable, optional) – The callback for logging. Defaults to None.
quadratic_weight (float, optional) – Weight of the quadratic term, in [0, 1]. Defaults to 0.005.
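The quadratic and quadratic_weight options refer to a fused Gromov-Wasserstein (FGW) loss. Such a loss combines a linear cost on omics features with a quadratic term that compares spatial distances within each time point. The sketch below evaluates the unregularized FGW objective of a given coupling in NumPy. It is not the stories implementation, which also uses entropic regularization (epsilon) and optimizes the coupling. It rests on three assumptions: Euclidean spatial distances, a squared-distance feature cost, and that quadratic_weight plays the role of α weighting the quadratic term. fused_gw_objective and pairwise_sq_dists are hypothetical helpers.

```python
import numpy as np


def pairwise_sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between the rows of a and the rows of b."""
    diff = a[:, None, :] - b[None, :, :]
    return np.sum(diff**2, axis=-1)


def fused_gw_objective(
    x_omics: np.ndarray,
    y_omics: np.ndarray,
    x_space: np.ndarray,
    y_space: np.ndarray,
    coupling: np.ndarray,
    alpha: float,
) -> float:
    """Unregularized FGW objective of a coupling (hypothetical helper).

    alpha weights the quadratic (spatial structure) term and 1 - alpha the
    linear (omics feature) term; alpha stands in for quadratic_weight.
    """
    # Linear term: feature cost between source cell i and target cell j.
    cost = pairwise_sq_dists(x_omics, y_omics)
    linear = np.sum(cost * coupling)

    # Quadratic term: sum_{ijkl} (A_ik - B_jl)^2 P_ij P_kl, comparing
    # spatial distances inside each slice.
    a = np.sqrt(pairwise_sq_dists(x_space, x_space))  # (n, n), indices (i, k)
    b = np.sqrt(pairwise_sq_dists(y_space, y_space))  # (m, m), indices (j, l)
    diff = (a[:, None, :, None] - b[None, :, None, :]) ** 2  # (i, j, k, l)
    quadratic = np.einsum("ijkl,ij,kl->", diff, coupling, coupling)

    return float((1.0 - alpha) * linear + alpha * quadratic)


rng = np.random.default_rng(0)
x_omics = rng.normal(size=(5, 3))  # 5 cells, 3 omics features
x_space = rng.normal(size=(5, 2))  # 2-d spatial coordinates
identity = np.eye(5) / 5  # match each cell to itself, uniform mass
print(fused_gw_objective(x_omics, x_omics, x_space, x_space, identity, alpha=0.005))  # 0.0
```

A small quadratic_weight such as the default 0.005 lets the feature cost dominate. The spatial term then acts as a mild structural regularizer.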
- debias: bool = True
- epsilon: float = 0.01
- fit(adata: anndata.AnnData, time_key: str, omics_key: str, space_key: str, weight_key: str | None = None, optimizer: optax.GradientTransformation = optax.chain(...), max_iter: int = 10000, batch_size: int = 1000, train_val_split: float = 0.75, min_delta: float = 0.0, patience: int = 150, checkpoint_manager: orbax.checkpoint.CheckpointManager | str | None = None, key: jax.Array = jax.random.PRNGKey(0), restore: bool = True) → None
Fit the model.
- Parameters:
adata (AnnData) – The AnnData object.
time_key (str) – The name of the time observation.
omics_key (str) – The name of the obsm field containing the omics coordinates of the cells.
space_key (str) – The name of the obsm field containing space coordinates.
weight_key (str, optional) – The name of the obs field containing cell weights. Defaults to None.
optimizer (GradientTransformation, optional) – The optimizer.
max_iter (int, optional) – The max number of iterations. Defaults to 10_000.
batch_size (int, optional) – The batch size. Defaults to 1_000.
train_val_split (float, optional) – The fraction of the data used for training. Defaults to 0.75.
min_delta (float, optional) – The minimum delta for early stopping. Defaults to 0.0.
patience (int, optional) – The patience for early stopping. Defaults to 150.
checkpoint_manager (CheckpointManager | str, optional) – A checkpoint manager or a checkpoint path. Defaults to None.
key (jax.Array, optional) – The random key. Defaults to PRNGKey(0).
restore (bool, optional) – Whether to load the checkpointed parameters. Defaults to True.
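To show how train_val_split, min_delta and patience typically interact, here is a hypothetical sketch. train_val_indices and EarlyStopping are not part of stories. The sketch follows the common convention that a validation loss counts as an improvement only if it beats the best loss so far by more than min_delta. Training stops after patience consecutive checks without improvement. Whether fit applies exactly this rule is an assumption.

```python
import numpy as np


def train_val_indices(n: int, train_val_split: float, seed: int = 0):
    """Shuffle n indices and split them into train and validation sets (hypothetical)."""
    perm = np.random.default_rng(seed).permutation(n)
    n_train = int(train_val_split * n)
    return perm[:n_train], perm[n_train:]


class EarlyStopping:
    """Hypothetical early-stopping rule based on min_delta and patience."""

    def __init__(self, min_delta: float = 0.0, patience: int = 150):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.wait = 0  # consecutive checks without a significant improvement

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            # Significant improvement: remember it and reset the counter.
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


train_idx, val_idx = train_val_indices(1000, 0.75)
print(len(train_idx), len(val_idx))  # 750 250

stopper = EarlyStopping(min_delta=0.0, patience=3)
for step, loss in enumerate([1.0, 0.8, 0.8, 0.9, 0.85]):
    if stopper.should_stop(loss):
        print("stopped at step", step)  # stopped at step 4
        break
```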
- log_callback: Callable | None = None
- n_steps: int = 1
- potential: Module = MLPPotential(features=(128, 128), activation=gelu)
- proximal_step: ProximalStep = stories.steps.explicit.ExplicitStep()
- quadratic: bool = True
- quadratic_weight: float = 0.005
- teacher_forcing: bool = True
- transform(adata: AnnData, omics_key: str, tau: float, batch_size: int = 1000, key: Array = jax.random.PRNGKey(0)) → ndarray
Transform an AnnData object.
- Parameters:
adata (AnnData) – The AnnData object to transform.
omics_key (str) – The obsm field containing the data to transform.
tau (float) – The time step.
batch_size (int, optional) – The batch size. Defaults to 1_000.
key (jax.Array, optional) – The random key. Defaults to PRNGKey(0).
- Returns:
The predictions for the data in adata.obsm[omics_key].
- Return type:
np.ndarray
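The following hypothetical sketch shows what a batched transform with the default explicit step could look like. predict_explicit and the toy quadratic potential are illustrative assumptions, not the stories implementation. The real method uses the learned potential and the configured proximal_step. Each cell takes one forward Euler step of size tau along −∇Φ, and rows are processed batch_size at a time so that memory stays bounded.

```python
from typing import Callable

import numpy as np


def predict_explicit(
    x: np.ndarray,
    grad_potential: Callable[[np.ndarray], np.ndarray],
    tau: float,
    batch_size: int = 1000,
) -> np.ndarray:
    """Move every cell one forward Euler step of size tau (hypothetical helper)."""
    out = np.empty_like(x)
    for start in range(0, x.shape[0], batch_size):
        batch = x[start : start + batch_size]
        # Explicit step: x_next = x - tau * grad Phi(x), computed per batch.
        out[start : start + batch_size] = batch - tau * grad_potential(batch)
    return out


# Toy potential Phi(x) = 0.5 * ||x - c||^2, whose gradient is x - c.
center = np.array([1.0, -1.0])


def grad_phi(z: np.ndarray) -> np.ndarray:
    return z - center


cells = np.zeros((5, 2))
pred = predict_explicit(cells, grad_phi, tau=0.5, batch_size=2)
print(pred[0])  # [ 0.5 -0.5]
```

Because the explicit step acts on each cell independently, the result does not depend on batch_size. Only peak memory usage does.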