
Schedule

StmtDag dataclass

StmtDag(
    id_table: IdTable[Statement] = (
        lambda: idtable.IdTable()
    )(),
    stmts: Dict[str, Statement] = OrderedDict(),
    out_edges: Dict[str, Set[str]] = OrderedDict(),
    inc_edges: Dict[str, Set[str]] = OrderedDict(),
    stmt_index: Dict[Statement, int] = OrderedDict(),
)

Bases: Graph[Statement]
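The field names suggest an adjacency-list layout keyed by statement id. stmts maps an id to its statement, out_edges maps an id to the ids of its successors, and inc_edges maps an id to the ids of its predecessors. topological_groups relies on the two edge maps mirroring each other: m is in out_edges[n] exactly when n is in inc_edges[m]. The sketch below is a hypothetical, dependency-free stand-in (MiniDag, add_node, and add_edge are not bloqade APIs) that stores string payloads instead of Statement objects and keeps that invariant.

```python
from collections import OrderedDict
from typing import Dict, Set


class MiniDag:
    """Hypothetical stand-in for StmtDag that stores plain string payloads."""

    def __init__(self) -> None:
        self.stmts: Dict[str, str] = OrderedDict()
        self.out_edges: Dict[str, Set[str]] = OrderedDict()  # id -> successor ids
        self.inc_edges: Dict[str, Set[str]] = OrderedDict()  # id -> predecessor ids

    def add_node(self, node_id: str, payload: str) -> None:
        # Every node gets an entry in both edge maps, even if it has no edges.
        self.stmts[node_id] = payload
        self.out_edges.setdefault(node_id, set())
        self.inc_edges.setdefault(node_id, set())

    def add_edge(self, src: str, dst: str) -> None:
        # Record "dst depends on src" in both directions so the maps stay mirrored.
        # Both nodes must already have been added with add_node.
        self.out_edges[src].add(dst)
        self.inc_edges[dst].add(src)


dag = MiniDag()
for name in ("a", "b", "c"):
    dag.add_node(name, f"stmt_{name}")
dag.add_edge("a", "b")
dag.add_edge("a", "c")
```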

topological_groups

topological_groups()

Split the DAG into topological groups. Each group contains nodes that have no dependencies on each other but depend on nodes in one or more previous groups.
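The definition can be restated as a checkable property: every node appears in exactly one group, and every dependency of a node lives in a strictly earlier group (which also rules out dependencies inside a group). The helper is_valid_grouping below is a hypothetical illustration, not part of bloqade; it takes a predecessor map shaped like inc_edges.

```python
from typing import Dict, List, Set


def is_valid_grouping(inc_edges: Dict[str, Set[str]], groups: List[List[str]]) -> bool:
    """Hypothetical checker: every dependency of a node must sit in an earlier group."""
    level: Dict[str, int] = {}
    for i, group in enumerate(groups):
        for node in group:
            if node in level:
                return False  # a node may appear in only one group
            level[node] = i
    if set(level) != set(inc_edges):
        return False  # every node must be scheduled
    # A dependency in the same or a later group breaks the definition.
    return all(
        level[dep] < level[node]
        for node, deps in inc_edges.items()
        for dep in deps
    )


# Diamond: a -> b, a -> c, b -> d, c -> d
diamond = {"a": set(), "b": {"a"}, "c": {"a"}, "d": {"b", "c"}}
print(is_valid_grouping(diamond, [["a"], ["b", "c"], ["d"]]))  # True
print(is_valid_grouping(diamond, [["a"], ["b", "c", "d"]]))    # False: d depends on b and c
```

The checker also accepts finer groupings such as [["a"], ["b"], ["c"], ["d"]], since the definition does not require groups to be as large as possible. The source below places each node in the earliest group its dependencies allow.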

Yields:

Type        Description
List[str]   A list of node ids in a topological group.

Raises:

Type        Description
ValueError  If a cyclic dependency is detected.
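Because topological_groups is a generator, the ValueError is raised only when iteration reaches the end, after every group that could be peeled off has already been yielded. On a cycle, the source as written also yields empty lists until the round limit is reached. The hypothetical toy_groups below is not part of bloqade; it imitates only that control flow, with [["a"], []] standing in for a DAG where a feeds a cycle.

```python
from typing import Iterator, List


def toy_groups(groups: List[List[str]], leftover: bool) -> Iterator[List[str]]:
    """Hypothetical imitation: yield precomputed groups, then raise if nodes remain."""
    for group in groups:
        yield group
    if leftover:
        raise ValueError("Cyclic dependency detected")


gen = toy_groups([["a"], []], leftover=True)  # calling the generator runs none of its body
seen: List[List[str]] = []
try:
    for group in gen:
        seen.append(group)  # groups arrive before the error does
except ValueError as err:
    print(seen, err)  # [['a'], []] Cyclic dependency detected
```

Callers that need all-or-nothing behavior can materialize the groups with list(...) first, so the error surfaces before any group is acted on.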

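The idea is to yield all nodes with no dependencies, then remove those nodes from the graph, repeating until no nodes are left or some upper limit on the number of rounds is reached. In an acyclic graph each round removes at least one node, and the worst case is a linear DAG that removes exactly one node per round, so len(dag.stmts) works as the upper limit.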

If we reach the limit and there are still nodes left, then we have a cyclic dependency.
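The peeling strategy and its round limit can be sketched on plain dictionaries. peel_groups is a hypothetical, simplified re-implementation, not a bloqade function: it takes a predecessor map like inc_edges, derives the successor sets itself, and scans every remaining node each round instead of only the successors of the previous group. It sorts each group only to make the output deterministic.

```python
from typing import Dict, Iterator, List, Set


def peel_groups(inc_edges: Dict[str, Set[str]]) -> Iterator[List[str]]:
    """Hypothetical sketch: repeatedly remove every node with no remaining dependencies."""
    remaining = {k: set(v) for k, v in inc_edges.items()}  # copy; caller's map is untouched
    # Derive successor sets so each removal can update its dependents.
    out_edges: Dict[str, Set[str]] = {k: set() for k in inc_edges}
    for node, deps in inc_edges.items():
        for dep in deps:
            out_edges[dep].add(node)

    # Each round of an acyclic graph removes at least one node,
    # so len(inc_edges) rounds always suffice (a linear chain needs all of them).
    for _ in range(len(inc_edges)):
        if not remaining:
            break
        group = sorted(n for n, deps in remaining.items() if not deps)
        yield group
        for n in group:
            del remaining[n]
            for m in out_edges[n]:
                remaining[m].discard(n)

    # Anything left after the limit is on, or downstream of, a cycle.
    if remaining:
        raise ValueError("Cyclic dependency detected")


chain = {"a": set(), "b": {"a"}, "c": {"b"}, "d": {"c"}}
print(list(peel_groups(chain)))  # [['a'], ['b'], ['c'], ['d']]
```

The four-node chain uses all four rounds, which is why the limit cannot be smaller than the node count. A two-node cycle such as {"x": {"y"}, "y": {"x"}} never produces a ready node, so the loop runs out and the ValueError is raised.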

Source code in .venv/lib/python3.12/site-packages/bloqade/squin/analysis/schedule.py
def topological_groups(self):
    """Split the dag into topological groups where each group
    contains nodes that have no dependencies on each other, but
    have dependencies on nodes in one or more previous groups.

    Yields:
        List[str]: A list of node ids in a topological group


    Raises:
        ValueError: If a cyclic dependency is detected


    The idea is to yield all nodes with no dependencies, then remove
    those nodes from the graph, repeating until no nodes are left
    or we reach some upper limit. Worst case is a linear DAG,
    so we can use len(dag.stmts) as the upper limit.

    If we reach the limit and there are still nodes left, then we
    have a cyclic dependency.
    """

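    # Work on a copy so that removing edges does not mutate the DAG itself.
    inc_edges = {k: set(v) for k, v in self.inc_edges.items()}

    # In the first round every node is a candidate.
    check_next = inc_edges.keys()

    # An acyclic graph loses at least one node per round, so this many rounds suffice.
    for _ in range(len(self.stmts)):
        if len(inc_edges) == 0:
            break  # every node has been placed in a group

        # Ready nodes: candidates whose dependencies have all been removed.
        group = [node_id for node_id in check_next if len(inc_edges[node_id]) == 0]
        # On a cycle this list can be empty; empty groups keep being yielded until the limit.
        yield group

        # Remove the group; only its successors can become ready in the next round.
        check_next = set()
        for n in group:
            inc_edges.pop(n)
            for m in self.out_edges[n]:
                check_next.add(m)
                inc_edges[m].remove(n)

    # Nodes still present after the loop are on, or depend on, a cycle.
    if inc_edges:
        raise ValueError("Cyclic dependency detected")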
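A typical consumer turns the yielded groups into a schedule, for example by giving every node the index of its group, so that nodes sharing an index have no dependencies on each other. group_index is a hypothetical helper, not part of bloqade. It accepts any iterable of groups, including the generator returned by topological_groups(); the hand-written groups below stand in for real output.

```python
from typing import Dict, Iterable, List


def group_index(groups: Iterable[List[str]]) -> Dict[str, int]:
    """Hypothetical helper: map each node id to the position of its topological group."""
    index: Dict[str, int] = {}
    for step, group in enumerate(groups):
        for node_id in group:
            index[node_id] = step
    return index


# Hand-written groups for a diamond a -> {b, c} -> d.
steps = group_index([["a"], ["b", "c"], ["d"]])
print(steps)  # {'a': 0, 'b': 1, 'c': 1, 'd': 2}
```

Since the helper consumes the iterable fully, passing it the generator directly also surfaces the cyclic-dependency ValueError.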