Skip to content

tracksdata.functional

Functional utilities for graph operations.

Classes:

Functions:

TilingScheme dataclass

TilingScheme(
    tile_shape: tuple[S, ...],
    overlap_shape: tuple[S, ...],
    attrs: list[str] | None = None,
)

Tiling scheme for the graph. Graph will be sliced with 'tile_shape' + 2 * 'overlap_shape' per axis.

Parameters:

  • tile_shape

    (tuple[S, ...]) –

    The shape of the tile.

  • overlap_shape

    (tuple[S, ...]) –

    The overlap between tiles PER SIDE.

  • attrs

    (list[str] | None, default: None ) –

    The attributes to include in the tile. If None, all attributes will be included. By default DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.Z, DEFAULT_ATTR_KEYS.Y, DEFAULT_ATTR_KEYS.X are included. If some columns are not present, they will be ignored.

ancestral_connected_edges

ancestral_connected_edges(
    input_graph: BaseGraph,
    reference_graph: BaseGraph,
    match: bool = True,
) -> list[int]

Let an ancestral path be any sequence of (target, source)-edges in the reference_graph. This function returns the subset of edges in the input_graph that are part of an ancestral path in the reference graph.

IMPORTANT: This function updates the input_graph in place when matching with the reference_graph.

Parameters:

  • input_graph

    (BaseGraph) –

    The input graph.

  • reference_graph

    (BaseGraph) –

    The reference graph.

  • match

    (bool, default: True ) –

    Whether to match the input graph with the reference graph.

Source code in src/tracksdata/functional/_labeling.py
def ancestral_connected_edges(
    input_graph: BaseGraph,
    reference_graph: BaseGraph,
    match: bool = True,
) -> list[int]:
    """
    Let an ancestral path be any sequence of (target, source)-edges in the `reference_graph`.
    This function returns the subset of edges in the `input_graph` that are
    part of an ancestral path in the reference graph.

    IMPORTANT: This function updates the `input_graph` in place when matching
    with the `reference_graph`. The `reference_graph` is also updated in place
    when it does not yet have a tracklet id node attribute, because tracklet
    ids are then assigned to it.

    Parameters
    ----------
    input_graph : BaseGraph
        The input graph.
    reference_graph : BaseGraph
        The reference graph.
    match : bool, optional
        Whether to match the input graph with the reference graph. When False,
        the input graph must already carry the matched-node-id attribute from
        a previous matching.

    Returns
    -------
    list[int]
        The edges of `input_graph` lying on an ancestral path of the reference
        graph, as produced by `_input_graph_ancestral_edges`.

    Raises
    ------
    ValueError
        If `match` is False and the input graph has no matched-node-id attribute.
    """
    # Validate before touching either graph so that an invalid call leaves
    # both graphs unmodified.
    if not match and DEFAULT_ATTR_KEYS.MATCHED_NODE_ID not in input_graph.node_attr_keys():
        raise ValueError(
            "`ancestral_connected_edges` requires the input graph to be previously matched "
            f"and have a `{DEFAULT_ATTR_KEYS.MATCHED_NODE_ID}` column when `match=False`"
        )

    # Reuse existing tracklet ids when present; otherwise assign them to the
    # reference graph (in place) and get the tracklet graph back.
    if DEFAULT_ATTR_KEYS.TRACKLET_ID not in reference_graph.node_attr_keys():
        tracklet_graph = reference_graph.assign_tracklet_ids()
    else:
        tracklet_graph = reference_graph.tracklet_graph()

    if match:
        input_graph.match(reference_graph)

    in_node_attrs = input_graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.MATCHED_NODE_ID])
    ref_node_attrs = reference_graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.TRACKLET_ID])

    # Keep only input nodes with a non-negative matched id (negative ids
    # presumably mark unmatched nodes -- confirm against `match`), then attach
    # the tracklet id of the reference node each one was matched to.
    in_node_attrs = in_node_attrs.filter(
        pl.col(DEFAULT_ATTR_KEYS.MATCHED_NODE_ID) >= 0,
    ).join(
        ref_node_attrs,
        left_on=DEFAULT_ATTR_KEYS.MATCHED_NODE_ID,
        right_on=DEFAULT_ATTR_KEYS.NODE_ID,
        how="left",
    )

    edge_attrs = input_graph.edge_attrs(attr_keys=[])

    # Annotate every input edge with the tracklet ids of its endpoints. With
    # the default left join, edges touching an unmatched node get nulls.
    edge_attrs = join_node_attrs_to_edges(
        node_attrs=in_node_attrs.select(DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.TRACKLET_ID),
        edge_attrs=edge_attrs,
    )

    tracklet_ancestral_edges = _ancestral_edges(tracklet_graph)
    return _input_graph_ancestral_edges(
        edge_attrs=edge_attrs,
        ancestral_edges=tracklet_ancestral_edges,
    )

apply_tiled

apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: None,
) -> Iterator[T]
apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: ReduceFunc,
) -> R
apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: ReduceFunc | None = None,
) -> Iterator[T] | R

Apply a function to a graph tiled by the tiling scheme. Graph will be sliced with 'tile_shape' + 2 * 'overlap_shape' per axis.

Parameters:

  • graph

    (BaseGraph) –

    The graph to apply the function to.

  • tiling_scheme

    (TilingScheme) –

    The tiling scheme to use.

  • func

    (MapFunc) –

    The function to apply to each tile. It takes two arguments:

      - filtered_graph_with_overlap: the subgraph inside the tile, including the overlap;
      - filtered_graph: the subgraph inside the tile, excluding the overlap.

    If all overlaps are 0, filtered_graph_with_overlap == filtered_graph, obtained with minimal overhead.

  • agg_func

    (ReduceFunc | None, default: None ) –

    The function to reduce the results of the function. If None, the results will be yielded.

Returns:

  • Iterator[T] | R

    The results of the function. If agg_func is provided, the results will be reduced. Otherwise, the results will be yielded.

Source code in src/tracksdata/functional/_apply.py
def apply_tiled(
    graph: BaseGraph,
    tiling_scheme: TilingScheme,
    func: MapFunc,
    *,
    agg_func: ReduceFunc | None = None,
) -> Iterator[T] | R:
    """
    Apply a function to every tile of a graph.

    Each tile spans 'tile_shape' + 2 * 'overlap_shape' per axis.

    Parameters
    ----------
    graph : BaseGraph
        The graph to apply the function to.
    tiling_scheme : TilingScheme
        The tiling scheme to use.
    func : MapFunc
        The function applied to each tile. It receives two arguments:
        - filtered_graph_with_overlap: the subgraph inside the tile, overlap included
        - filtered_graph: the subgraph inside the tile, overlap excluded
        When every overlap is 0, filtered_graph_with_overlap == filtered_graph
        and is obtained with minimal overhead.
    agg_func : ReduceFunc | None, optional
        Reduction applied to the per-tile results. When None, the per-tile
        results are returned lazily as an iterator.

    Returns
    -------
    Iterator[T] | R
        The reduced value when `agg_func` is given, otherwise a lazy iterator
        over the per-tile results.
    """
    # The generator lives in a separate function: any function whose body
    # contains `yield` is a generator function, so a `return agg_func(...)`
    # in the same body would never hand the reduced value to the caller.
    per_tile_results = _yield_apply_tiled(
        graph=graph,
        tiling_scheme=tiling_scheme,
        func=func,
    )
    return per_tile_results if agg_func is None else agg_func(per_tile_results)

join_node_attrs_to_edges

join_node_attrs_to_edges(
    node_attrs: DataFrame,
    edge_attrs: DataFrame,
    node_id_key: str = DEFAULT_ATTR_KEYS.NODE_ID,
    source_key: str = DEFAULT_ATTR_KEYS.EDGE_SOURCE,
    target_key: str = DEFAULT_ATTR_KEYS.EDGE_TARGET,
    source_prefix: str = "source_",
    target_prefix: str = "target_",
    how: JoinStrategy = "left",
) -> pl.DataFrame

Add node attributes to edge attributes by joining on the node ID.

Parameters:

  • node_attrs

    (DataFrame) –

    Node attributes.

  • edge_attrs

    (DataFrame) –

    Edge attributes.

  • node_id_key

    (str, default: NODE_ID ) –

    The key of the node ID column.

  • source_key

    (str, default: EDGE_SOURCE ) –

    The key of the source column.

  • target_key

    (str, default: EDGE_TARGET ) –

    The key of the target column.

  • source_prefix

    (str, default: 'source_' ) –

    The prefix of the source column.

  • target_prefix

    (str, default: 'target_' ) –

    The prefix of the target column.

  • how

    (JoinStrategy, default: 'left' ) –

    The join type, where the "left" dataframe is the edge attributes and the "right" dataframe is the node attributes.

Returns:

  • DataFrame

    Edge attributes with node attributes added.

Examples:

node_attrs = pl.DataFrame({"node_id": [1, 2, 3], "a": [1, 2, 3], "b": [4, 5, 6]})
edge_attrs = pl.DataFrame({"source": [1, 2], "target": [2, 3]})
join_node_attrs_to_edges(
    node_attrs,
    edge_attrs,
    node_id_key="node_id",
    source_key="source",
    target_key="target",
)
# shape: (2, 6)
# ┌────────┬────────┬──────────┬──────────┬──────────┬──────────┐
# │ source ┆ target ┆ source_a ┆ source_b ┆ target_a ┆ target_b │
# │ ---    ┆ ---    ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
# │ i64    ┆ i64    ┆ i64      ┆ i64      ┆ i64      ┆ i64      │
# ╞════════╪════════╪══════════╪══════════╪══════════╪══════════╡
# │ 1      ┆ 2      ┆ 1        ┆ 4        ┆ 2        ┆ 5        │
# │ 2      ┆ 3      ┆ 2        ┆ 5        ┆ 3        ┆ 6        │
# └────────┴────────┴──────────┴──────────┴──────────┴──────────┘
Source code in src/tracksdata/functional/_edges.py
def join_node_attrs_to_edges(
    node_attrs: pl.DataFrame,
    edge_attrs: pl.DataFrame,
    node_id_key: str = DEFAULT_ATTR_KEYS.NODE_ID,
    source_key: str = DEFAULT_ATTR_KEYS.EDGE_SOURCE,
    target_key: str = DEFAULT_ATTR_KEYS.EDGE_TARGET,
    source_prefix: str = "source_",
    target_prefix: str = "target_",
    how: JoinStrategy = "left",
) -> pl.DataFrame:
    """
    Add node attributes to edge attributes by joining on the node ID.

    Every column of `node_attrs` except `node_id_key` is added twice to the
    edges: once for the source node (prefixed with `source_prefix`) and once
    for the target node (prefixed with `target_prefix`).

    Parameters
    ----------
    node_attrs : pl.DataFrame
        Node attributes. Must contain the `node_id_key` column.
    edge_attrs : pl.DataFrame
        Edge attributes. Must contain the `source_key` and `target_key` columns.
    node_id_key : str, optional
        The key of the node ID column.
    source_key : str, optional
        The key of the source column.
    target_key : str, optional
        The key of the target column.
    source_prefix : str, optional
        The prefix of the source column.
    target_prefix : str, optional
        The prefix of the target column.
    how : JoinStrategy, optional
        The join type, where the "left" dataframe is the edge attributes and
        the "right" dataframe is the node attributes.

    Returns
    -------
    pl.DataFrame
        Edge attributes with node attributes added.

    Raises
    ------
    ValueError
        If `node_id_key` is not a column of `node_attrs`.

    Examples
    --------
    ```python
    node_attrs = pl.DataFrame({"node_id": [1, 2, 3], "a": [1, 2, 3], "b": [4, 5, 6]})
    edge_attrs = pl.DataFrame({"source": [1, 2], "target": [2, 3]})
    join_node_attrs_to_edges(
        node_attrs,
        edge_attrs,
        node_id_key="node_id",
        source_key="source",
        target_key="target",
    )
    # shape: (2, 6)
    # ┌────────┬────────┬──────────┬──────────┬──────────┬──────────┐
    # │ source ┆ target ┆ source_a ┆ source_b ┆ target_a ┆ target_b │
    # │ ---    ┆ ---    ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
    # │ i64    ┆ i64    ┆ i64      ┆ i64      ┆ i64      ┆ i64      │
    # ╞════════╪════════╪══════════╪══════════╪══════════╪══════════╡
    # │ 1      ┆ 2      ┆ 1        ┆ 4        ┆ 2        ┆ 5        │
    # │ 2      ┆ 3      ┆ 2        ┆ 5        ┆ 3        ┆ 6        │
    # └────────┴────────┴──────────┴──────────┴──────────┴──────────┘
    ```
    """
    if node_id_key not in node_attrs.columns:
        raise ValueError(
            f"`node_id_key` '{node_id_key}' not found in `node_attrs` columns: {node_attrs.columns}"
        )

    # Every node column except the id is copied onto each endpoint of the edge.
    node_attr_keys = [col for col in node_attrs.columns if col != node_id_key]

    # Id column first, then the attributes; shared by both joins. The id column
    # itself is not renamed because it is the right-hand join key.
    nodes = node_attrs.select(node_id_key, *node_attr_keys)

    edge_attrs = edge_attrs.join(
        nodes.rename({col: f"{source_prefix}{col}" for col in node_attr_keys}),
        left_on=source_key,
        right_on=node_id_key,
        how=how,
    ).join(
        nodes.rename({col: f"{target_prefix}{col}" for col in node_attr_keys}),
        left_on=target_key,
        right_on=node_id_key,
        how=how,
    )

    return edge_attrs

rx_digraph_to_napari_dict

rx_digraph_to_napari_dict(
    tracklet_graph: PyDiGraph,
) -> dict[int, list[int]]

Convert a tracklet graph to a napari-ready dictionary. Each edge is read as a (parent, child) pair, and the result maps every child to the list of its parents (napari's {track_id: [parent_track_ids]} graph layout).

Parameters:

  • tracklet_graph

    (PyDiGraph) –

    The tracklet graph to convert.

Returns:

  • dict[int, list[int]]

    A dictionary mapping each child to the list of its parents.

Source code in src/tracksdata/functional/_napari.py
def rx_digraph_to_napari_dict(
    tracklet_graph: "rx.PyDiGraph",
) -> dict[int, list[int]]:
    """
    Convert a tracklet graph to a napari-ready dictionary.

    Each item yielded by ``tracklet_graph.edges()`` is unpacked as a
    ``(parent, child)`` pair. The result maps every child to the list of its
    parents, i.e. the ``{track_id: [parent_track_ids]}`` layout napari's
    Tracks layer expects for its ``graph`` argument.

    Parameters
    ----------
    tracklet_graph : rx.PyDiGraph
        The tracklet graph to convert.

    Returns
    -------
    dict[int, list[int]]
        A dictionary mapping each child to the list of its parents, in edge
        iteration order. Graphs without edges produce an empty dictionary.
    """
    # NOTE(review): rustworkx `PyDiGraph.edges()` returns the edge *payloads*,
    # not (source, target) index pairs (that is `edge_list()`). This loop only
    # works if every edge payload is itself a (parent, child) pair, whereas
    # `to_napari_format` walks `edge_list()` and reads node payloads instead --
    # confirm which form the callers' graphs provide.
    dict_graph: dict[int, list[int]] = {}
    for parent, child in tracklet_graph.edges():
        dict_graph.setdefault(child, []).append(parent)
    return dict_graph

to_napari_format

to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None,
    solution_key: str | None,
    output_tracklet_id_key: str,
    mask_key: None,
) -> tuple[pl.DataFrame, dict[int, int]]
to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None,
    solution_key: str | None,
    output_tracklet_id_key: str,
    mask_key: str,
) -> tuple[pl.DataFrame, dict[int, int], GraphArrayView]
to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None = None,
    solution_key: str | None = DEFAULT_ATTR_KEYS.SOLUTION,
    output_tracklet_id_key: str = DEFAULT_ATTR_KEYS.TRACKLET_ID,
    mask_key: str | None = None,
    chunk_shape: tuple[int] | None = None,
    buffer_cache_size: int | None = None,
) -> (
    tuple[pl.DataFrame, dict[int, int], GraphArrayView]
    | tuple[pl.DataFrame, dict[int, int]]
)

Convert the subgraph of solution nodes to a napari-ready format.

This includes:

  - a tracks layer with the solution tracks;
  - a graph with the parent-child relationships for the solution tracks;
  - a labels layer with the solution nodes, if mask_key is provided.

IMPORTANT: This function will reset the track ids if they already exist.

Parameters:

  • graph

    (BaseGraph) –

    The graph to convert.

  • shape

    (tuple[int, ...] | None, default: None ) –

    The shape of the labels layer. If None, the shape is inferred from the graph metadata shape key.

  • solution_key

    (str, default: SOLUTION ) –

    The key of the solution attribute. If None, the graph is not filtered by the solution attribute.

  • output_tracklet_id_key

    (str, default: TRACKLET_ID ) –

    The key of the output track id attribute.

  • mask_key

    (str | None, default: None ) –

    The key of the mask attribute.

  • chunk_shape

    (tuple[int] | None, default: None ) –

    The chunk shape for the labels layer. If None, the default chunk size is used.

  • buffer_cache_size

    (int, default: None ) –

    The maximum number of buffers to keep in the cache for the labels layer. If None, the default buffer cache size is used.

Examples:

labels = ...
graph = ...
tracks_data, dict_graph, array_view = to_napari_format(graph, labels.shape, mask_key="mask")

Returns:

  • tuple[DataFrame, dict[int, int], GraphArrayView] | tuple[DataFrame, dict[int, int]]
    • tracks_data: The tracks data as a polars DataFrame.
    • dict_graph: A dictionary mapping each child tracklet id to its parent tracklet id.
    • array_view: The array view of the solution graph if mask_key is provided.
Source code in src/tracksdata/functional/_napari.py
def to_napari_format(
    graph: BaseGraph,
    shape: tuple[int, ...] | None = None,
    solution_key: str | None = DEFAULT_ATTR_KEYS.SOLUTION,
    output_tracklet_id_key: str = DEFAULT_ATTR_KEYS.TRACKLET_ID,
    mask_key: str | None = None,
    chunk_shape: tuple[int] | None = None,
    buffer_cache_size: int | None = None,
) -> (
    tuple[
        pl.DataFrame,
        dict[int, int],
        "GraphArrayView",
    ]
    | tuple[
        pl.DataFrame,
        dict[int, int],
    ]
):
    """
    Convert the subgraph of solution nodes to a napari-ready format.

    This includes:
    - a tracks layer with the solution tracks
    - a graph with the parent-child relationships for the solution tracks
    - a labels layer with the solution nodes if `mask_key` is provided.

    IMPORTANT: This function will reset the track ids if they already exist.

    Parameters
    ----------
    graph : BaseGraph
        The graph to convert.
    shape : tuple[int, ...] | None, optional
        The shape of the labels layer, time axis first. If None, the shape is
        inferred from the graph metadata `shape` key.
    solution_key : str | None, optional
        The key of the solution attribute. If None, the graph is not filtered by the solution attribute.
    output_tracklet_id_key : str, optional
        The key of the output track id attribute.
    mask_key : str | None, optional
        The key of the mask attribute.
    chunk_shape : tuple[int] | None, optional
        The chunk shape for the labels layer. If None, the default chunk size is used.
    buffer_cache_size : int | None, optional
        The maximum number of buffers to keep in the cache for the labels layer.
        If None, the default buffer cache size is used.

    Examples
    --------

    ```python
    labels = ...
    graph = ...
    tracks_data, dict_graph, array_view = to_napari_format(graph, labels.shape, mask_key="mask")
    ```

    Returns
    -------
    tuple[pl.DataFrame, dict[int, int], GraphArrayView] | tuple[pl.DataFrame, dict[int, int]]
        - tracks_data: The tracks data as a polars DataFrame, with columns
          `output_tracklet_id_key`, the time key, then the last `len(shape) - 1`
          of the (z, y, x) keys.
        - dict_graph: A dictionary mapping each child tracklet id to its parent
          tracklet id. If a tracklet has several parents, only one is kept.
        - array_view: The array view of the solution graph if `mask_key` is provided.
    """
    if solution_key is not None:
        # `== True` is intentional: it builds an attribute filter expression,
        # so it cannot be replaced by `is True`.
        solution_graph = graph.filter(
            NodeAttr(solution_key) == True,  # noqa: E712
            EdgeAttr(solution_key) == True,  # noqa: E712
        ).subgraph()

    else:
        solution_graph = graph

    shape = _validate_shape(shape, solution_graph, "to_napari_format")

    # Writes `output_tracklet_id_key` onto the nodes and returns the tracklet
    # graph; its node payloads are presumably the tracklet ids -- confirm.
    tracks_graph = solution_graph.assign_tracklet_ids(output_tracklet_id_key)
    # child tracklet -> parent tracklet; with several parents, the last edge wins.
    dict_graph = {tracks_graph[child]: tracks_graph[parent] for parent, child in tracks_graph.edge_list()}

    # `shape` is (T, *spatial): keep the trailing `len(shape) - 1` of (Z, Y, X).
    # The start index is clamped at 0, and a time-only shape yields no spatial
    # columns. A plain `[-len(shape) + 1:]` slice became `[0:]` when
    # len(shape) == 1 and wrongly selected all of Z, Y, X.
    num_spatial = len(shape) - 1
    spatial_cols = [DEFAULT_ATTR_KEYS.Z, DEFAULT_ATTR_KEYS.Y, DEFAULT_ATTR_KEYS.X][max(0, 3 - num_spatial) :]

    tracks_data = solution_graph.node_attrs(
        attr_keys=[output_tracklet_id_key, DEFAULT_ATTR_KEYS.T, *spatial_cols],
    )

    # napari expects the columns in (track_id, t, [z,] y, x) order
    tracks_data = tracks_data.select([output_tracklet_id_key, DEFAULT_ATTR_KEYS.T, *spatial_cols])

    if mask_key is not None:
        # imported lazily, only when a labels layer is requested
        from tracksdata.array._graph_array import GraphArrayView

        array_view = GraphArrayView(
            solution_graph,
            shape=shape,
            attr_key=output_tracklet_id_key,
            chunk_shape=chunk_shape,
            buffer_cache_size=buffer_cache_size,
        )

        return tracks_data, dict_graph, array_view

    return tracks_data, dict_graph