Skip to content

API Reference

Public API

swcviz package scaffolding.

Public API is evolving; currently exposes SWC parsing utilities and models.

SWCRecord dataclass

One SWC row.

Attributes

- n: int — Node id (unique within file)
- t: int — Structure type code
- x, y, z: float — Coordinates (usually micrometers)
- r: float — Radius
- parent: int — Parent id; -1 indicates root
- line: int — 1-based line number in the source file/string

Source code in swcviz/io.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
@dataclass(frozen=True)
class SWCRecord:
    """A single data row of an SWC source, as produced by the parser.

    Instances are immutable (frozen dataclass), so they compare and hash by value.

    Attributes
    ----------
    n: int
        Identifier of this node; unique within one file.
    t: int
        Structure type code of the node.
    x, y, z: float
        Position of the node (usually micrometers).
    r: float
        Radius at this node.
    parent: int
        Identifier of the parent node; -1 marks a root.
    line: int
        1-based line number of the row in the source file/string.
    """

    # Identity and classification
    n: int
    t: int
    # Spatial position
    x: float
    y: float
    z: float
    # Radius at this point
    r: float
    # Parent link (-1 for a root)
    parent: int
    # Provenance: 1-based source line
    line: int

SWCParseResult dataclass

Parsed SWC content.

Source code in swcviz/io.py
72
73
74
75
76
77
78
79
80
81
82
83
84
@dataclass(frozen=True)
class SWCParseResult:
    """Outcome of parsing one SWC source.

    Attributes
    ----------
    records: Dict[int, SWCRecord]
        Parsed rows keyed by node id.
    reconnections: List[Tuple[int, int]]
        Pairs of node ids that downstream model builders merge into one point.
    comments: List[str]
        Comment text captured during parsing (presumably header lines; see parse_swc).
    """

    records: Dict[int, SWCRecord]
    reconnections: List[Tuple[int, int]]
    comments: List[str]

    def __str__(self) -> str:
        # Summarize by counts only; the full record table can be very large.
        n_records = len(self.records)
        n_reconnections = len(self.reconnections)
        n_comments = len(self.comments)
        return (
            f"SWCParseResult(records={n_records}, "
            f"reconnections={n_reconnections}, comments={n_comments})"
        )

    def __repr__(self) -> str:
        # Reuse the compact count summary instead of the dataclass-generated field dump.
        return str(self)

SWCModel

Bases: DiGraph

Directed SWC morphology graph.

Nodes are keyed by SWC id n and store attributes:

- t: int (structure type)
- x, y, z: float (coordinates)
- r: float (radius)
- line: int (line number in source; informational)

Edges are directed parent -> child.

Source code in swcviz/model.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
class SWCModel(nx.DiGraph):
    """Directed SWC morphology graph.

    Nodes are keyed by SWC id `n` and store attributes:
    - t: int (structure type)
    - x, y, z: float (coordinates)
    - r: float (radius)
    - line: int (line number in source; informational)

    Edges are directed parent -> child.
    """

    def __init__(self) -> None:
        # Initialize as a plain DiGraph; we don't need multigraph features.
        super().__init__()

    # ----------------------------------------------------------------------------------------------
    # Construction helpers
    # ----------------------------------------------------------------------------------------------
    @classmethod
    def from_parse_result(cls, result: SWCParseResult) -> "SWCModel":
        """Build a model from a parsed SWC result.

        Only ``result.records`` is used; reconnections and comments are not
        consulted here.
        """
        return cls.from_records(result.records)

    @classmethod
    def from_records(
        cls, records: Mapping[int, SWCRecord] | Iterable[SWCRecord]
    ) -> "SWCModel":
        """Build a model from SWC records.

        Accepts either a mapping of id->record or any iterable of SWCRecord.
        Every record becomes a node carrying t, x, y, z, r and line; every record
        whose parent is not -1 contributes a parent -> child edge.
        """
        model = cls()

        # Materialize to a list once so we can iterate twice safely
        # (a one-shot iterator would be exhausted after the node pass).
        if isinstance(records, Mapping):
            rec_values = list(records.values())
        else:
            rec_values = list(records)

        # First pass: add all nodes with attributes
        for rec in rec_values:
            model.add_node(
                rec.n,
                t=rec.t,
                x=rec.x,
                y=rec.y,
                z=rec.z,
                r=rec.r,
                line=rec.line,
            )

        # Second pass: add edges parent -> child.
        # NOTE(review): a parent id with no matching record is passed straight to
        # add_edge; confirm whether parse_swc guarantees this cannot happen.
        for rec in rec_values:
            if rec.parent != -1:
                model.add_edge(rec.parent, rec.n)

        return model

    @classmethod
    def from_swc_file(
        cls,
        source: str | os.PathLike[str] | Iterable[str],
        *,
        strict: bool = True,
        validate_reconnections: bool = True,
        float_tol: float = 1e-9,
    ) -> "SWCModel":
        """Parse an SWC source then build a model.

        The `source` is passed through to `parse_swc`, which supports a path,
        a file-like object, a string with the full contents, or an iterable of lines.
        `strict`, `validate_reconnections` and `float_tol` are forwarded unchanged.
        """
        result = parse_swc(
            source,
            strict=strict,
            validate_reconnections=validate_reconnections,
            float_tol=float_tol,
        )
        return cls.from_parse_result(result)

    # ----------------------------------------------------------------------------------------------
    # Convenience queries
    # ----------------------------------------------------------------------------------------------
    def roots(self) -> list[int]:
        """Return nodes with in-degree 0 (forest roots)."""
        return [n for n, deg in self.in_degree() if deg == 0]

    def parent_of(self, n: int) -> int | None:
        """Return the parent id of node n, or None if n is a root.

        SWC trees should have at most one parent per node; if multiple are found
        this indicates invalid structure for SWC and an error is raised.

        Raises
        ------
        ValueError
            If `n` has more than one predecessor.
        """
        preds = list(self.predecessors(n))
        if not preds:
            return None
        if len(preds) > 1:
            raise ValueError(
                f"Node {n} has multiple parents in SWCModel; expected a tree/forest"
            )
        return preds[0]

    def path_to_root(self, n: int) -> list[int]:
        """Return the path from node n up to its root, inclusive.

        Example: For edges 1->2->3, `path_to_root(3)` returns `[3, 2, 1]`.

        Raises
        ------
        ValueError
            If the parent chain revisits a node (a cycle or self-loop), which would
            otherwise make the walk loop forever; also propagated from `parent_of`
            when a node has more than one parent.
        """
        path: list[int] = [n]
        visited: set[int] = {n}
        current = n
        while True:
            p = self.parent_of(current)
            if p is None:
                break
            # from_records adds edges without checking for cycles, so a chain of
            # single parents can loop back on itself and never reach a root.
            if p in visited:
                raise ValueError(
                    f"Cycle detected while walking from node {n} to its root: "
                    f"node {p} was reached twice"
                )
            path.append(p)
            visited.add(p)
            current = p
        return path

    def print_attributes(self, *, node_info: bool = False, edge_info: bool = False) -> None:
        """Print graph attributes and optional node/edge details.

        Parameters
        ----------
        node_info: bool
            If True, print per-node attributes (t, x, y, z, r, line where present).
        edge_info: bool
            If True, print all edges (u -> v) with edge attributes if any.
        """
        info = _graph_attributes(self)
        header = (
            f"SWCModel: nodes={info['nodes']}, edges={info['edges']}, "
            f"components={info['components']}, cycles={info['cycles']}, "
            f"branch_points={info['branch_points_count']}, roots={info['roots_count']}, "
            f"leaves={info['leaves_count']}, self_loops={info['self_loops']}, density={info['density']:.4f}"
        )
        print(header)

        if node_info:
            print("Nodes:")
            ordered = ["t", "x", "y", "z", "r", "line"]
            for n, attrs in self.nodes(data=True):
                parts = [f"{k}={attrs[k]}" for k in ordered if k in attrs]
                print(f"  {n}: " + ", ".join(parts))

        if edge_info:
            print("Edges:")
            for u, v, attrs in self.edges(data=True):
                if attrs:
                    print(f"  {u} -> {v}: {dict(attrs)}")
                else:
                    print(f"  {u} -> {v}")

from_parse_result(result) classmethod

Build a model from a parsed SWC result.

Source code in swcviz/model.py
 98
 99
100
101
@classmethod
def from_parse_result(cls, result: SWCParseResult) -> "SWCModel":
    """Construct a model from the records of an already-parsed SWC result.

    Only ``result.records`` is read; reconnections and comments are ignored here.
    """
    parsed_records = result.records
    return cls.from_records(parsed_records)

from_records(records) classmethod

Build a model from SWC records.

Accepts either a mapping of id->record or any iterable of SWCRecord.

Source code in swcviz/model.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@classmethod
def from_records(
    cls, records: Mapping[int, SWCRecord] | Iterable[SWCRecord]
) -> "SWCModel":
    """Create a model whose nodes and parent -> child edges come from SWC records.

    `records` may be a mapping from id to record (only its values are used) or
    any iterable of records, including a one-shot iterator.
    """
    model = cls()

    # Snapshot the input once: it is traversed twice (nodes, then edges).
    items = list(records.values() if isinstance(records, Mapping) else records)

    # All nodes go in first so each edge below joins two already-attributed nodes.
    for rec in items:
        model.add_node(
            rec.n, t=rec.t, x=rec.x, y=rec.y, z=rec.z, r=rec.r, line=rec.line
        )

    # A parent id of -1 marks a root, which contributes no edge.
    links = [(rec.parent, rec.n) for rec in items if rec.parent != -1]
    for parent_id, child_id in links:
        model.add_edge(parent_id, child_id)

    return model

from_swc_file(source, *, strict=True, validate_reconnections=True, float_tol=1e-09) classmethod

Parse an SWC source then build a model.

The source is passed through to parse_swc, which supports a path, a file-like object, a string with the full contents, or an iterable of lines.

Source code in swcviz/model.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
@classmethod
def from_swc_file(
    cls,
    source: str | os.PathLike[str] | Iterable[str],
    *,
    strict: bool = True,
    validate_reconnections: bool = True,
    float_tol: float = 1e-9,
) -> "SWCModel":
    """Run `parse_swc` on `source` and turn the result into an SWCModel.

    `source` may be a path, a file-like object, a string holding the whole file,
    or an iterable of lines; it is handed to `parse_swc` unchanged, together with
    `strict`, `validate_reconnections` and `float_tol`.
    """
    parse_options = {
        "strict": strict,
        "validate_reconnections": validate_reconnections,
        "float_tol": float_tol,
    }
    parsed = parse_swc(source, **parse_options)
    return cls.from_parse_result(parsed)

roots()

Return nodes with in-degree 0 (forest roots).

Source code in swcviz/model.py
163
164
165
def roots(self) -> list[int]:
    """List every node with no incoming edge; each one starts a tree of the forest."""
    found: list[int] = []
    for node, indegree in self.in_degree():
        if indegree == 0:
            found.append(node)
    return found

parent_of(n)

Return the parent id of node n, or None if n is a root.

SWC trees should have at most one parent per node; if multiple are found this indicates invalid structure for SWC and an error is raised.

Source code in swcviz/model.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def parent_of(self, n: int) -> int | None:
    """Look up the single parent of node `n`; return None when `n` is a root.

    A valid SWC tree gives each node at most one parent, so finding several
    predecessors means the structure is broken and a ValueError is raised.
    """
    parents = [*self.predecessors(n)]
    if not parents:
        return None
    if len(parents) == 1:
        return parents[0]
    raise ValueError(
        f"Node {n} has multiple parents in SWCModel; expected a tree/forest"
    )

path_to_root(n)

Return the path from node n up to its root, inclusive.

Example: For edges 1->2->3, path_to_root(3) returns [3, 2, 1].

Source code in swcviz/model.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def path_to_root(self, n: int) -> list[int]:
    """Return the path from node n up to its root, inclusive.

    Example: For edges 1->2->3, `path_to_root(3)` returns `[3, 2, 1]`.

    Raises
    ------
    ValueError
        If the parent chain revisits a node (a cycle or self-loop), which would
        otherwise make the walk loop forever; also propagated from `parent_of`
        when a node has more than one parent.
    """
    path: list[int] = [n]
    visited: set[int] = {n}
    current = n
    while True:
        p = self.parent_of(current)
        if p is None:
            break
        # Models built from arbitrary records are not checked for acyclicity, so
        # a chain of single parents can loop back on itself and never reach a root.
        if p in visited:
            raise ValueError(
                f"Cycle detected while walking from node {n} to its root: "
                f"node {p} was reached twice"
            )
        path.append(p)
        visited.add(p)
        current = p
    return path

print_attributes(*, node_info=False, edge_info=False)

Print graph attributes and optional node/edge details.

Parameters

- node_info: bool — If True, print per-node attributes (t, x, y, z, r, line where present).
- edge_info: bool — If True, print all edges (u -> v) with edge attributes if any.

Source code in swcviz/model.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def print_attributes(self, *, node_info: bool = False, edge_info: bool = False) -> None:
    """Print a one-line summary of the graph, optionally followed by node/edge listings.

    Parameters
    ----------
    node_info: bool
        When True, print one line per node showing whichever of
        t, x, y, z, r, line it carries, in that order.
    edge_info: bool
        When True, print one line per directed edge ``u -> v``, appending the
        edge's attribute dict when it is non-empty.
    """
    info = _graph_attributes(self)
    print(
        f"SWCModel: nodes={info['nodes']}, edges={info['edges']}, "
        f"components={info['components']}, cycles={info['cycles']}, "
        f"branch_points={info['branch_points_count']}, roots={info['roots_count']}, "
        f"leaves={info['leaves_count']}, self_loops={info['self_loops']}, density={info['density']:.4f}"
    )

    if node_info:
        print("Nodes:")
        wanted = ("t", "x", "y", "z", "r", "line")
        for n, attrs in self.nodes(data=True):
            shown = ", ".join(f"{k}={attrs[k]}" for k in wanted if k in attrs)
            print(f"  {n}: " + shown)

    if not edge_info:
        return
    print("Edges:")
    for u, v, attrs in self.edges(data=True):
        print(f"  {u} -> {v}: {dict(attrs)}" if attrs else f"  {u} -> {v}")

GeneralModel

Bases: Graph

Undirected morphology graph with reconnection merges.

  • Subclasses networkx.Graph.
  • Nodes correspond to merged SWC points according to header annotations # CYCLE_BREAK reconnect i j.
  • Node attributes include: x, y, z, r (identical across merged ids), representative n, optional t, and provenance lists merged_ids, lines.
  • Edges are undirected between merged nodes; self-loops are skipped if parent/child collapse into the same merged node.
Source code in swcviz/model.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
class GeneralModel(nx.Graph):
    """Undirected morphology graph in which reconnected SWC points are merged.

    - Built on `networkx.Graph`.
    - Each graph node stands for a group of SWC ids joined by the header
      annotations `# CYCLE_BREAK reconnect i j`; the node key is the group's
      union-find root.
    - Node attributes: `x, y, z, r` (identical across merged ids), `n` (the
      smallest id in the group), `t`, and the provenance lists `merged_ids`
      and `lines`.
    - Edges are undirected; a parent/child pair that collapses into a single
      merged node contributes no edge (self-loops are skipped).
    """

    def __init__(self) -> None:
        # A plain networkx.Graph is all that is needed.
        super().__init__()

    # ------------------------------------------------------------------------------------------
    # Construction helpers
    # ------------------------------------------------------------------------------------------
    @classmethod
    def from_parse_result(
        cls,
        result: SWCParseResult,
        *,
        validate_reconnections: bool = True,
        float_tol: float = 1e-9,
    ) -> "GeneralModel":
        """Merge reconnected records of a parse result into an undirected model.

        When `validate_reconnections` is True, each reconnect pair must agree on
        x, y, z and r to within `float_tol` before it is merged (handy when
        `parse_swc` ran with validation turned off).

        Raises
        ------
        ValueError
            If a reconnect pair names an id with no record, or (with validation
            on) the two records of a pair differ in x, y, z or r.
        """
        records = result.records

        # Disjoint-set forest over SWC ids: `leader` maps an id to its parent in
        # the forest; `height` stores the union-by-rank value of each root.
        leader: dict[int, int] = {n: n for n in records}
        height: dict[int, int] = {n: 0 for n in records}

        def find_root(node: int) -> int:
            # Climb to the root; an id seen for the first time becomes its own set.
            top = node
            while leader.get(top, top) != top:
                top = leader[top]
            leader.setdefault(top, top)
            height.setdefault(top, 0)
            # Path compression: re-point every id on the walked path at the root.
            while node != top:
                nxt = leader[node]
                leader[node] = top
                node = nxt
            return top

        def same_geometry(a: "SWCRecord", b: "SWCRecord") -> bool:
            pairs = ((a.x, b.x), (a.y, b.y), (a.z, b.z), (a.r, b.r))
            return all(abs(p - q) <= float_tol for p, q in pairs)

        def merge(a: int, b: int) -> None:
            root_a, root_b = find_root(a), find_root(b)
            if root_a == root_b:
                return
            rank_a, rank_b = height.get(root_a, 0), height.get(root_b, 0)
            # Higher rank wins; on a tie the smaller id becomes the root, so the
            # outcome does not depend on argument order.
            a_wins = rank_a > rank_b or (rank_a == rank_b and root_a < root_b)
            if a_wins:
                winner, loser, win_rank, lose_rank = root_a, root_b, rank_a, rank_b
            else:
                winner, loser, win_rank, lose_rank = root_b, root_a, rank_b, rank_a
            leader[loser] = winner
            height[winner] = max(win_rank, lose_rank + 1)

        # Merge every annotated reconnection pair.
        for i, j in result.reconnections:
            if i not in records or j not in records:
                raise ValueError(
                    f"Reconnection pair ({i}, {j}) refers to undefined node id(s)"
                )
            if validate_reconnections and not same_geometry(records[i], records[j]):
                raise ValueError(
                    "Reconnection requires identical (x, y, z, r) but got:\n"
                    f"  {i}: (x={records[i].x}, y={records[i].y}, z={records[i].z}, r={records[i].r})\n"
                    f"  {j}: (x={records[j].x}, y={records[j].y}, z={records[j].z}, r={records[j].r})"
                )
            merge(i, j)

        # Collect the ids belonging to each root, in record order.
        members: dict[int, list[int]] = {}
        for n in records:
            members.setdefault(find_root(n), []).append(n)

        # One graph node per group; attributes come from the smallest id, since
        # merged records share coordinates by contract.
        model = cls()
        for root, group in members.items():
            ordered = sorted(group)
            head = records[ordered[0]]
            model.add_node(
                root,
                n=ordered[0],
                x=head.x,
                y=head.y,
                z=head.z,
                r=head.r,
                # Representative type; may vary across merged ids.
                t=head.t,
                merged_ids=ordered,
                lines=sorted(records[k].line for k in ordered),
            )

        # Undirected edges between group roots; parent/child inside one group is dropped.
        child_links = ((rec.parent, rec.n) for rec in records.values() if rec.parent != -1)
        for parent_id, child_id in child_links:
            u = find_root(parent_id)
            v = find_root(child_id)
            if u != v:
                model.add_edge(u, v)

        return model

    @classmethod
    def from_swc_file(
        cls,
        source: str | os.PathLike[str] | Iterable[str],
        *,
        strict: bool = True,
        validate_reconnections: bool = True,
        float_tol: float = 1e-9,
    ) -> "GeneralModel":
        """Run `parse_swc` on `source` and merge the result into a GeneralModel.

        `validate_reconnections` and `float_tol` are forwarded both to `parse_swc`
        and to `from_parse_result`; `strict` goes to `parse_swc` only.
        """
        shared = {"validate_reconnections": validate_reconnections, "float_tol": float_tol}
        parsed = parse_swc(source, strict=strict, **shared)
        return cls.from_parse_result(parsed, **shared)

    def print_attributes(self, *, node_info: bool = False, edge_info: bool = False) -> None:
        """Print a one-line graph summary, optionally followed by node/edge listings.

        Parameters
        ----------
        node_info: bool
            When True, print one line per node showing whichever of
            n, x, y, z, r, t, merged_ids, lines it carries, in that order.
        edge_info: bool
            When True, print one line per undirected edge ``u -- v``, appending
            the edge's attribute dict when it is non-empty.
        """
        info = _graph_attributes(self)
        print(
            f"GeneralModel: nodes={info['nodes']}, edges={info['edges']}, "
            f"components={info['components']}, cycles={info['cycles']}, "
            f"branch_points={info['branch_points_count']}, leaves={info['leaves_count']}, "
            f"self_loops={info['self_loops']}, density={info['density']:.4f}"
        )

        if node_info:
            print("Nodes:")
            wanted = ("n", "x", "y", "z", "r", "t", "merged_ids", "lines")
            for n, attrs in self.nodes(data=True):
                shown = ", ".join(f"{k}={attrs[k]}" for k in wanted if k in attrs)
                print(f"  {n}: " + shown)

        if not edge_info:
            return
        print("Edges:")
        for u, v, attrs in self.edges(data=True):
            print(f"  {u} -- {v}: {dict(attrs)}" if attrs else f"  {u} -- {v}")

from_parse_result(result, *, validate_reconnections=True, float_tol=1e-09) classmethod

Build a merged undirected model from a parsed SWC result.

If validate_reconnections is True, enforce identical (x, y, z, r) for each reconnect pair before merging (useful when parse_swc was called with validation disabled).

Source code in swcviz/model.py
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
@classmethod
def from_parse_result(
    cls,
    result: SWCParseResult,
    *,
    validate_reconnections: bool = True,
    float_tol: float = 1e-9,
) -> "GeneralModel":
    """Merge reconnected records of a parse result into an undirected model.

    When `validate_reconnections` is True, each reconnect pair must agree on
    x, y, z and r to within `float_tol` before it is merged (handy when
    `parse_swc` ran with validation turned off).

    Raises
    ------
    ValueError
        If a reconnect pair names an id with no record, or (with validation on)
        the two records of a pair differ in x, y, z or r.
    """
    records = result.records

    # Disjoint-set forest over SWC ids: `leader` maps an id to its parent in
    # the forest; `height` stores the union-by-rank value of each root.
    leader: dict[int, int] = {n: n for n in records}
    height: dict[int, int] = {n: 0 for n in records}

    def find_root(node: int) -> int:
        # Climb to the root; an id seen for the first time becomes its own set.
        top = node
        while leader.get(top, top) != top:
            top = leader[top]
        leader.setdefault(top, top)
        height.setdefault(top, 0)
        # Path compression: re-point every id on the walked path at the root.
        while node != top:
            nxt = leader[node]
            leader[node] = top
            node = nxt
        return top

    def same_geometry(a: "SWCRecord", b: "SWCRecord") -> bool:
        pairs = ((a.x, b.x), (a.y, b.y), (a.z, b.z), (a.r, b.r))
        return all(abs(p - q) <= float_tol for p, q in pairs)

    def merge(a: int, b: int) -> None:
        root_a, root_b = find_root(a), find_root(b)
        if root_a == root_b:
            return
        rank_a, rank_b = height.get(root_a, 0), height.get(root_b, 0)
        # Higher rank wins; on a tie the smaller id becomes the root, so the
        # outcome does not depend on argument order.
        a_wins = rank_a > rank_b or (rank_a == rank_b and root_a < root_b)
        if a_wins:
            winner, loser, win_rank, lose_rank = root_a, root_b, rank_a, rank_b
        else:
            winner, loser, win_rank, lose_rank = root_b, root_a, rank_b, rank_a
        leader[loser] = winner
        height[winner] = max(win_rank, lose_rank + 1)

    # Merge every annotated reconnection pair.
    for i, j in result.reconnections:
        if i not in records or j not in records:
            raise ValueError(
                f"Reconnection pair ({i}, {j}) refers to undefined node id(s)"
            )
        if validate_reconnections and not same_geometry(records[i], records[j]):
            raise ValueError(
                "Reconnection requires identical (x, y, z, r) but got:\n"
                f"  {i}: (x={records[i].x}, y={records[i].y}, z={records[i].z}, r={records[i].r})\n"
                f"  {j}: (x={records[j].x}, y={records[j].y}, z={records[j].z}, r={records[j].r})"
            )
        merge(i, j)

    # Collect the ids belonging to each root, in record order.
    members: dict[int, list[int]] = {}
    for n in records:
        members.setdefault(find_root(n), []).append(n)

    # One graph node per group; attributes come from the smallest id, since
    # merged records share coordinates by contract.
    model = cls()
    for root, group in members.items():
        ordered = sorted(group)
        head = records[ordered[0]]
        model.add_node(
            root,
            n=ordered[0],
            x=head.x,
            y=head.y,
            z=head.z,
            r=head.r,
            # Representative type; may vary across merged ids.
            t=head.t,
            merged_ids=ordered,
            lines=sorted(records[k].line for k in ordered),
        )

    # Undirected edges between group roots; parent/child inside one group is dropped.
    child_links = ((rec.parent, rec.n) for rec in records.values() if rec.parent != -1)
    for parent_id, child_id in child_links:
        u = find_root(parent_id)
        v = find_root(child_id)
        if u != v:
            model.add_edge(u, v)

    return model

from_swc_file(source, *, strict=True, validate_reconnections=True, float_tol=1e-09) classmethod

Parse an SWC source and build a merged undirected model.

Source code in swcviz/model.py
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
@classmethod
def from_swc_file(
    cls,
    source: str | os.PathLike[str] | Iterable[str],
    *,
    strict: bool = True,
    validate_reconnections: bool = True,
    float_tol: float = 1e-9,
) -> "GeneralModel":
    """Run `parse_swc` on `source` and merge the result into a GeneralModel.

    `validate_reconnections` and `float_tol` are forwarded both to `parse_swc`
    and to `from_parse_result`; `strict` goes to `parse_swc` only.
    """
    shared = {"validate_reconnections": validate_reconnections, "float_tol": float_tol}
    parsed = parse_swc(source, strict=strict, **shared)
    return cls.from_parse_result(parsed, **shared)

print_attributes(*, node_info=False, edge_info=False)

Print graph attributes and optional node/edge details.

Parameters

- node_info: bool — If True, print per-node attributes (n, x, y, z, r, t, merged_ids, lines where present).
- edge_info: bool — If True, print all edges (u -- v) with edge attributes if any.

Source code in swcviz/model.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
def print_attributes(self, *, node_info: bool = False, edge_info: bool = False) -> None:
    """Print a one-line graph summary, optionally followed by node/edge listings.

    Parameters
    ----------
    node_info: bool
        When True, print one line per node showing whichever of
        n, x, y, z, r, t, merged_ids, lines it carries, in that order.
    edge_info: bool
        When True, print one line per undirected edge ``u -- v``, appending the
        edge's attribute dict when it is non-empty.
    """
    info = _graph_attributes(self)
    print(
        f"GeneralModel: nodes={info['nodes']}, edges={info['edges']}, "
        f"components={info['components']}, cycles={info['cycles']}, "
        f"branch_points={info['branch_points_count']}, leaves={info['leaves_count']}, "
        f"self_loops={info['self_loops']}, density={info['density']:.4f}"
    )

    if node_info:
        print("Nodes:")
        wanted = ("n", "x", "y", "z", "r", "t", "merged_ids", "lines")
        for n, attrs in self.nodes(data=True):
            shown = ", ".join(f"{k}={attrs[k]}" for k in wanted if k in attrs)
            print(f"  {n}: " + shown)

    if not edge_info:
        return
    print("Edges:")
    for u, v, attrs in self.edges(data=True):
        print(f"  {u} -- {v}: {dict(attrs)}" if attrs else f"  {u} -- {v}")

Segment dataclass

Oriented frustum segment between endpoints a and b.

Attributes

- a, b: Point3 — Endpoints in model/world coordinates.
- ra, rb: float — Radii at a and b.

Source code in swcviz/geometry.py
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
@dataclass(frozen=True)
class Segment:
    """Oriented frustum running from endpoint `a` to endpoint `b`.

    Attributes
    ----------
    a, b: Point3
        Endpoint positions in model/world coordinates.
    ra, rb: float
        Radius at `a` and at `b` respectively.
    """

    a: Point3
    b: Point3
    ra: float
    rb: float

    def vector(self) -> Vec3:
        """Return the displacement from `a` to `b`, computed as ``v_sub(b, a)``."""
        start, end = self.a, self.b
        return v_sub(end, start)

    def length(self) -> float:
        """Return the segment length, i.e. ``v_norm`` applied to :meth:`vector`."""
        delta = self.vector()
        return v_norm(delta)

    def midpoint(self) -> Point3:
        """Return the point halfway between `a` and `b` as a 3-tuple."""
        return tuple(self.a[axis] * 0.5 + self.b[axis] * 0.5 for axis in range(3))

FrustaSet dataclass

A batched frusta mesh derived from a GeneralModel.

Attributes

- vertices: List[Point3] — Concatenated vertices for all frusta.
- faces: List[Face] — Triangular faces indexing into vertices.
- sides: int — Circumferential resolution used per frustum.
- end_caps: bool — Whether end caps were included during construction.
- segment_count: int — Number of segments used (one per graph edge).
- edge_count: int — Alias for segment_count for clarity.

Source code in swcviz/geometry.py
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
@dataclass(frozen=True)
class FrustaSet:
    """A batched frusta mesh derived from a `GeneralModel`.

    Attributes
    ----------
    vertices: List[Point3]
        Concatenated vertices for all frusta.
    faces: List[Face]
        Triangular faces indexing into `vertices`.
    sides: int
        Circumferential resolution used per frustum.
    end_caps: bool
        Whether end caps were included during construction.
    segment_count: int
        Number of segments used (one per graph edge).
    edge_count: int
        Alias for `segment_count` for clarity.
    segments: List[Segment]
        The segments the mesh was built from; `scaled` rebuilds the mesh from
        these.
    """

    vertices: List[Point3]
    faces: List[Face]
    sides: int
    end_caps: bool
    segment_count: int
    edge_count: int
    segments: List[Segment]

    @classmethod
    def from_general_model(
        cls,
        gm: Any,
        *,
        sides: int = 16,
        end_caps: bool = False,
    ) -> "FrustaSet":
        """Build a `FrustaSet` by converting each edge `(u, v)` of `gm` into a `Segment`.

        Parameters
        ----------
        gm
            Graph-like object exposing `edges` (iterable of `(u, v)` pairs) and
            `nodes[n]` attribute mappings with `x`, `y`, `z`, `r` entries.
        sides
            Circumferential resolution per frustum, forwarded to `batch_frusta`.
        end_caps
            Whether to close the frusta with end caps, forwarded to `batch_frusta`.

        Returns
        -------
        FrustaSet
            One segment per edge, running from node `u` to node `v`.
        """
        nodes = gm.nodes
        segments: List[Segment] = []
        for u, v in gm.edges:
            nu, nv = nodes[u], nodes[v]
            # Coordinates are coerced to float just like the radii, so segments
            # always hold plain floats regardless of how the graph stores them.
            segments.append(
                Segment(
                    a=(float(nu["x"]), float(nu["y"]), float(nu["z"])),
                    b=(float(nv["x"]), float(nv["y"]), float(nv["z"])),
                    ra=float(nu["r"]),
                    rb=float(nv["r"]),
                )
            )

        vertices, faces = batch_frusta(segments, sides=sides, end_caps=end_caps)
        return cls(
            vertices=vertices,
            faces=faces,
            sides=sides,
            end_caps=end_caps,
            segment_count=len(segments),
            edge_count=len(segments),
            segments=segments,
        )

    def to_mesh3d_arrays(
        self,
    ) -> Tuple[List[float], List[float], List[float], List[int], List[int], List[int]]:
        """Return Plotly Mesh3d arrays: x, y, z, i, j, k."""
        x = [p[0] for p in self.vertices]
        y = [p[1] for p in self.vertices]
        z = [p[2] for p in self.vertices]
        i = [f[0] for f in self.faces]
        j = [f[1] for f in self.faces]
        k = [f[2] for f in self.faces]
        return x, y, z, i, j, k

    def scaled(self, radius_scale: float) -> "FrustaSet":
        """Return a new FrustaSet with all segment radii scaled by `radius_scale`.

        This rebuilds vertices/faces from the stored `segments` list. A scale of
        exactly 1.0 returns `self` unchanged. The result is built with
        `dataclasses.replace`, so subclasses (and any extra fields they declare)
        are preserved.
        """
        from dataclasses import replace

        if radius_scale == 1.0:
            return self
        scaled_segments = [
            Segment(a=s.a, b=s.b, ra=s.ra * radius_scale, rb=s.rb * radius_scale)
            for s in self.segments
        ]
        vertices, faces = batch_frusta(
            scaled_segments, sides=self.sides, end_caps=self.end_caps
        )
        return replace(
            self,
            vertices=vertices,
            faces=faces,
            segments=scaled_segments,
        )

from_general_model(gm, *, sides=16, end_caps=False) classmethod

Build a FrustaSet by converting each undirected edge into a Segment.

Expects nodes to have attributes x, y, z, r.

Source code in swcviz/geometry.py
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
@classmethod
def from_general_model(
    cls,
    gm: Any,
    *,
    sides: int = 16,
    end_caps: bool = False,
) -> "FrustaSet":
    """Build a `FrustaSet` by converting each edge `(u, v)` of `gm` into a `Segment`.

    Parameters
    ----------
    gm
        Graph-like object exposing `edges` (iterable of `(u, v)` pairs) and
        `nodes[n]` attribute mappings with `x`, `y`, `z`, `r` entries.
    sides
        Circumferential resolution per frustum, forwarded to `batch_frusta`.
    end_caps
        Whether to close the frusta with end caps, forwarded to `batch_frusta`.

    Returns
    -------
    FrustaSet
        One segment per edge, running from node `u` to node `v`.
    """
    nodes = gm.nodes
    segments: List[Segment] = []
    for u, v in gm.edges:
        nu, nv = nodes[u], nodes[v]
        # Coordinates are coerced to float just like the radii, so segments
        # always hold plain floats regardless of how the graph stores them.
        segments.append(
            Segment(
                a=(float(nu["x"]), float(nu["y"]), float(nu["z"])),
                b=(float(nv["x"]), float(nv["y"]), float(nv["z"])),
                ra=float(nu["r"]),
                rb=float(nv["r"]),
            )
        )

    vertices, faces = batch_frusta(segments, sides=sides, end_caps=end_caps)
    return cls(
        vertices=vertices,
        faces=faces,
        sides=sides,
        end_caps=end_caps,
        segment_count=len(segments),
        edge_count=len(segments),
        segments=segments,
    )

to_mesh3d_arrays()

Return Plotly Mesh3d arrays: x, y, z, i, j, k.

Source code in swcviz/geometry.py
500
501
502
503
504
505
506
507
508
509
510
def to_mesh3d_arrays(
    self,
) -> Tuple[List[float], List[float], List[float], List[int], List[int], List[int]]:
    """Split vertices and faces into the six flat arrays Plotly `Mesh3d` expects.

    Returns
    -------
    tuple
        ``(x, y, z, i, j, k)``: the per-vertex coordinates followed by the
        first, second and third vertex index of every face.
    """

    def column(rows, position):
        # Pull one component out of every row (vertex or face tuple).
        return [row[position] for row in rows]

    return (
        column(self.vertices, 0),
        column(self.vertices, 1),
        column(self.vertices, 2),
        column(self.faces, 0),
        column(self.faces, 1),
        column(self.faces, 2),
    )

scaled(radius_scale)

Return a new FrustaSet with all segment radii scaled by radius_scale.

This rebuilds vertices/faces from the stored segments list.

Source code in swcviz/geometry.py
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
def scaled(self, radius_scale: float) -> "FrustaSet":
    """Return a new FrustaSet with all segment radii scaled by `radius_scale`.

    This rebuilds vertices/faces from the stored `segments` list. A scale of
    exactly 1.0 returns `self` unchanged. The result is built with
    `dataclasses.replace`, so subclasses (and any extra fields they declare)
    are preserved.
    """
    from dataclasses import replace

    if radius_scale == 1.0:
        return self
    scaled_segments = [
        Segment(a=s.a, b=s.b, ra=s.ra * radius_scale, rb=s.rb * radius_scale)
        for s in self.segments
    ]
    vertices, faces = batch_frusta(
        scaled_segments, sides=self.sides, end_caps=self.end_caps
    )
    return replace(
        self,
        vertices=vertices,
        faces=faces,
        segments=scaled_segments,
    )

PointSet dataclass

A batched mesh of small spheres placed at given 3D points.

Source code in swcviz/geometry.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
@dataclass(frozen=True)
class PointSet:
    """A batched mesh of small spheres placed at given 3D points.

    Attributes
    ----------
    vertices: List[Point3]
        Concatenated vertices of all spheres.
    faces: List[Face]
        Triangular faces indexing into `vertices`.
    points: List[Point3]
        Sphere centres, in input order.
    base_radius: float
        Radius the spheres in `vertices` were generated with.
    stacks: int
        Sphere tessellation parameter forwarded to `batch_spheres`.
    slices: int
        Sphere tessellation parameter forwarded to `batch_spheres`.
    """

    vertices: List[Point3]
    faces: List[Face]
    points: List[Point3]
    base_radius: float
    stacks: int
    slices: int

    @classmethod
    def from_points(
        cls,
        points: Sequence[Point3],
        *,
        base_radius: float = 1.0,
        stacks: int = 6,
        slices: int = 12,
    ) -> "PointSet":
        """Mesh one sphere of radius `base_radius` centred at each point in `points`."""
        # Materialize once: meshing and storing both iterate the points, which
        # would silently leave `points` empty if a one-shot iterator were passed.
        pts = list(points)
        verts, faces = batch_spheres(
            pts, radius=base_radius, stacks=stacks, slices=slices
        )
        return cls(
            vertices=verts,
            faces=faces,
            points=pts,
            base_radius=base_radius,
            stacks=stacks,
            slices=slices,
        )

    @classmethod
    def from_txt(
        cls,
        source: Union[str, os.PathLike, Iterable[str], io.TextIOBase],
        *,
        base_radius: float = 1.0,
        stacks: int = 6,
        slices: int = 12,
        allow_extra_columns: bool = True,
    ) -> "PointSet":
        """Load a simple text format with `x y z` coordinates per non-empty line.

        - `source` may be a file-like object, a path (str or `os.PathLike`), an
          iterable of lines, or a string holding the content itself. A `str` that
          is not an existing path is treated as content.
        - Lines beginning with `#` or blank lines are ignored.
        - If `allow_extra_columns=True`, extra columns after the first three are ignored.
        - Raises `ValueError` on malformed lines (chained to the parse error).
        - Raises `FileNotFoundError` if an `os.PathLike` source does not exist.
        """

        # Normalize to an iterable of lines
        lines: Iterable[str]
        if hasattr(source, "read"):
            # file-like or IO stream; iterating yields lines
            lines = source  # type: ignore[assignment]
        elif isinstance(source, (str, os.PathLike)):
            p = str(source)
            if os.path.exists(p):
                with open(p, "r", encoding="utf-8") as f:
                    lines = f.read().splitlines()
            elif isinstance(source, os.PathLike):
                # A path object is never inline text; reporting it as such would
                # surface as a misleading "expected at least 3 columns" error.
                raise FileNotFoundError(f"Points file not found: {p}")
            else:
                # Plain string that is not an existing path: treat as content.
                lines = p.splitlines()
        else:
            lines = source

        pts: List[Point3] = []
        for idx, raw in enumerate(lines, start=1):
            s = raw.strip()
            if not s or s.startswith("#"):
                continue
            parts = s.split()
            if len(parts) < 3:
                raise ValueError(
                    f"Line {idx}: expected at least 3 columns for x y z, got {len(parts)}"
                )
            if not allow_extra_columns and len(parts) != 3:
                raise ValueError(
                    f"Line {idx}: expected exactly 3 columns for x y z, got {len(parts)}"
                )
            try:
                x = float(parts[0])
                y = float(parts[1])
                z = float(parts[2])
            except ValueError as e:
                raise ValueError(f"Line {idx}: could not parse floats: {e}") from e
            pts.append((x, y, z))

        return cls.from_points(
            pts, base_radius=base_radius, stacks=stacks, slices=slices
        )

    def to_mesh3d_arrays(
        self,
    ) -> Tuple[List[float], List[float], List[float], List[int], List[int], List[int]]:
        """Return Plotly Mesh3d arrays: x, y, z, i, j, k."""
        x = [p[0] for p in self.vertices]
        y = [p[1] for p in self.vertices]
        z = [p[2] for p in self.vertices]
        i = [f[0] for f in self.faces]
        j = [f[1] for f in self.faces]
        k = [f[2] for f in self.faces]
        return x, y, z, i, j, k

    def scaled(self, radius_scale: float) -> "PointSet":
        """Return a new `PointSet` with all sphere radii scaled by `radius_scale`.

        The spheres are re-meshed at ``base_radius * radius_scale``, and the
        result records that radius as its own `base_radius`. The metadata
        therefore matches the geometry, and repeated scaling compounds, as it
        does for `FrustaSet.scaled`. A scale of exactly 1.0 returns `self`.
        """
        from dataclasses import replace

        if radius_scale == 1.0:
            return self
        r = self.base_radius * radius_scale
        verts, faces = batch_spheres(
            self.points, radius=r, stacks=self.stacks, slices=self.slices
        )
        return replace(self, vertices=verts, faces=faces, base_radius=r)

from_txt(source, *, base_radius=1.0, stacks=6, slices=12, allow_extra_columns=True) classmethod

Load a simple text format with x y z coordinates per non-empty line.

  • Lines beginning with # or blank lines are ignored.
  • If allow_extra_columns=True, extra columns after the first three are ignored.
  • Raises ValueError on malformed lines.
Source code in swcviz/geometry.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
@classmethod
def from_txt(
    cls,
    source: Union[str, os.PathLike, Iterable[str], io.TextIOBase],
    *,
    base_radius: float = 1.0,
    stacks: int = 6,
    slices: int = 12,
    allow_extra_columns: bool = True,
) -> "PointSet":
    """Load a simple text format with `x y z` coordinates per non-empty line.

    - `source` may be a file-like object, a path (str or `os.PathLike`), an
      iterable of lines, or a string holding the content itself. A `str` that
      is not an existing path is treated as content.
    - Lines beginning with `#` or blank lines are ignored.
    - If `allow_extra_columns=True`, extra columns after the first three are ignored.
    - Raises `ValueError` on malformed lines (chained to the parse error).
    - Raises `FileNotFoundError` if an `os.PathLike` source does not exist.
    """

    # Normalize to an iterable of lines
    lines: Iterable[str]
    if hasattr(source, "read"):
        # file-like or IO stream; iterating yields lines
        lines = source  # type: ignore[assignment]
    elif isinstance(source, (str, os.PathLike)):
        p = str(source)
        if os.path.exists(p):
            with open(p, "r", encoding="utf-8") as f:
                lines = f.read().splitlines()
        elif isinstance(source, os.PathLike):
            # A path object is never inline text; reporting it as such would
            # surface as a misleading "expected at least 3 columns" error.
            raise FileNotFoundError(f"Points file not found: {p}")
        else:
            # Plain string that is not an existing path: treat as content.
            lines = p.splitlines()
    else:
        lines = source

    pts: List[Point3] = []
    for idx, raw in enumerate(lines, start=1):
        s = raw.strip()
        if not s or s.startswith("#"):
            continue
        parts = s.split()
        if len(parts) < 3:
            raise ValueError(
                f"Line {idx}: expected at least 3 columns for x y z, got {len(parts)}"
            )
        if not allow_extra_columns and len(parts) != 3:
            raise ValueError(
                f"Line {idx}: expected exactly 3 columns for x y z, got {len(parts)}"
            )
        try:
            x = float(parts[0])
            y = float(parts[1])
            z = float(parts[2])
        except ValueError as e:
            raise ValueError(f"Line {idx}: could not parse floats: {e}") from e
        pts.append((x, y, z))

    return cls.from_points(
        pts, base_radius=base_radius, stacks=stacks, slices=slices
    )

scaled(radius_scale)

Return a new PointSet with all sphere radii scaled by radius_scale.

Source code in swcviz/geometry.py
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
def scaled(self, radius_scale: float) -> "PointSet":
    """Return a new `PointSet` with all sphere radii scaled by `radius_scale`.

    The spheres are re-meshed at ``base_radius * radius_scale``, and the
    result records that radius as its own `base_radius`. The metadata
    therefore matches the geometry, and repeated scaling compounds, as it
    does for `FrustaSet.scaled`. A scale of exactly 1.0 returns `self`.
    """
    from dataclasses import replace

    if radius_scale == 1.0:
        return self
    r = self.base_radius * radius_scale
    verts, faces = batch_spheres(
        self.points, radius=r, stacks=self.stacks, slices=self.slices
    )
    return replace(self, vertices=verts, faces=faces, base_radius=r)

parse_swc(source, *, strict=True, validate_reconnections=True, float_tol=1e-09)

Parse an SWC file or text stream.

Parameters

source — path to an SWC file, a file-like object, an iterable of lines, or a string containing SWC content. strict — if True, enforce 7-column rows and validate that parent references exist. validate_reconnections — if True, ensure reconnection node pairs share identical (x, y, z, r). float_tol — tolerance used when comparing floating-point coordinates/radii.

Returns

SWCParseResult Parsed records, reconnection pairs, and collected comments.

Raises

ValueError — if parsing or validation fails. FileNotFoundError — if a string path is provided that does not exist.

Source code in swcviz/io.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def parse_swc(
    source: Union[str, os.PathLike, Iterable[str], io.TextIOBase],
    *,
    strict: bool = True,
    validate_reconnections: bool = True,
    float_tol: float = 1e-9,
) -> SWCParseResult:
    """Read SWC morphology data from a path, stream, iterable of lines, or string.

    Parameters
    ----------
    source
        Path to an SWC file, a file-like object, an iterable of lines, or a
        string holding SWC content. Line enumeration is delegated to
        `_iter_lines`.
    strict
        When True, rows must have exactly 7 columns, and every parent id other
        than -1 must name a node defined somewhere in the input. When False,
        columns after the seventh are ignored and parents are not checked.
    validate_reconnections
        When True, each reconnection pair found in comment lines must refer to
        defined nodes whose (x, y, z, r) agree within `float_tol`.
    float_tol
        Tolerance handed to `_close` for the reconnection comparison.

    Returns
    -------
    SWCParseResult
        Records keyed by node id (in input order), reconnection pairs
        normalised to ``(smaller_id, larger_id)``, and the comment lines.

    Raises
    ------
    ValueError
        On short rows, extra columns in strict mode, unparsable values,
        duplicate ids, dangling parents (strict mode), or failed reconnection
        checks.
    FileNotFoundError
        If a string path is provided that does not exist.
    """
    nodes: Dict[int, SWCRecord] = {}
    comment_lines: List[str] = []
    pairs: List[Tuple[int, int]] = []

    for lineno, raw in _iter_lines(source):
        text = raw.strip()
        if not text:
            continue

        if text.startswith("#"):
            # Comments are kept (minus the trailing newline); some of them
            # declare a reconnection between two node ids.
            comment_lines.append(raw.rstrip("\n"))
            hit = _RECONNECT_RE.match(raw)
            if hit:
                lo, hi = sorted((int(hit.group("i")), int(hit.group("j"))))
                pairs.append((lo, hi))
            continue

        cols = text.split()
        ncols = len(cols)
        if ncols < 7:
            raise ValueError(
                f"Line {lineno}: expected 7 columns 'n T x y z r parent', got {ncols}"
            )
        if ncols > 7 and strict:
            raise ValueError(
                f"Line {lineno}: expected exactly 7 columns, got {ncols}"
            )

        try:
            node_id = int(_coerce_int(cols[0]))
            type_code = int(_coerce_int(cols[1]))
            x, y, z, radius = map(float, cols[2:6])
            parent_id = int(_coerce_int(cols[6]))
        except Exception as e:  # noqa: BLE001
            raise ValueError(f"Line {lineno}: failed to parse values -> {e}") from e

        earlier = nodes.get(node_id)
        if earlier is not None:
            raise ValueError(
                f"Line {lineno}: duplicate node id {node_id} (previously defined at line {earlier.line})"
            )

        nodes[node_id] = SWCRecord(
            n=node_id, t=type_code, x=x, y=y, z=z, r=radius, parent=parent_id, line=lineno
        )

    # Parent references are checked only after every row is read, so a parent
    # may legitimately appear later in the file than its child.
    if strict:
        dangling = next(
            (
                rec
                for rec in nodes.values()
                if rec.parent != -1 and rec.parent not in nodes
            ),
            None,
        )
        if dangling is not None:
            raise ValueError(
                f"Line {dangling.line}: parent id {dangling.parent} does not exist for node {dangling.n}"
            )

    if validate_reconnections:
        for a, b in pairs:
            if a not in nodes or b not in nodes:
                raise ValueError(
                    f"Reconnection pair ({a}, {b}) refers to undefined node id(s)"
                )
            first, second = nodes[a], nodes[b]
            matches = all(
                _close(p, q, float_tol)
                for p, q in (
                    (first.x, second.x),
                    (first.y, second.y),
                    (first.z, second.z),
                    (first.r, second.r),
                )
            )
            if not matches:
                raise ValueError(
                    "Reconnection requires identical (x, y, z, r) but got:\n"
                    f"  {a}: (x={first.x}, y={first.y}, z={first.z}, r={first.r})\n"
                    f"  {b}: (x={second.x}, y={second.y}, z={second.z}, r={second.r})"
                )

    return SWCParseResult(
        records=nodes, reconnections=pairs, comments=comment_lines
    )

frustum_mesh(seg, *, sides=16, end_caps=False)

Generate a frustum mesh for a single Segment.

Returns

(vertices, faces), where vertices is a List[Point3] and faces is a List[Face], each face being a triple (i, j, k) of indices into vertices.

Source code in swcviz/geometry.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
def frustum_mesh(
    seg: Segment, *, sides: int = 16, end_caps: bool = False
) -> Tuple[List[Point3], List[Face]]:
    """Generate a frustum mesh for a single `Segment`.

    The mesh is made of two vertex rings, one around `seg.a` and one around
    `seg.b`, joined by two triangles per side quad. When requested, each end
    with a strictly positive radius is closed by a triangle fan around an extra
    centre vertex.

    Parameters
    ----------
    seg
        Segment whose endpoints and radii define the frustum.
    sides
        Number of vertices per ring (circumferential resolution).
    end_caps
        If True, add a cap fan at each end whose radius is > 0.

    Returns
    -------
    (vertices, faces):
        - vertices: List[Point3]. Ring at `a` (indices ``0..sides-1``), ring
          at `b` (``sides..2*sides-1``), then the optional cap centres.
        - faces: List[Face], each = (i, j, k) indexing into `vertices`.
    """
    # Local frame: U and V are handed to `_circle_ring` to lay out both rings;
    # the third axis is not needed here.
    U, V, _W = _orthonormal_frame(seg.vector())

    ring_a = _circle_ring(seg.a, seg.ra, U, V, sides)
    ring_b = _circle_ring(seg.b, seg.rb, U, V, sides)

    vertices: List[Point3] = []
    vertices.extend(ring_a)
    vertices.extend(ring_b)

    faces: List[Face] = []

    # Side faces (two triangles per quad)
    for i in range(sides):
        a0 = i
        a1 = (i + 1) % sides
        b0 = i + sides
        b1 = ((i + 1) % sides) + sides
        faces.append((a0, b0, b1))
        faces.append((a0, b1, a1))

    # Optional end caps. The fans are wound so that their orientation agrees
    # with the side triangles above, independent of the handedness of the
    # (U, V) frame. This assumes both rings list their points in the same
    # angular order, which holds because they share U and V. The cap at `a`
    # uses (centre, i, i+1) and the cap at `b` uses (centre, i+1, i). Rim
    # vertices shared by caps and sides therefore get consistent normals.
    if end_caps and seg.ra > 0.0:
        ca = len(vertices)
        vertices.append(seg.a)
        for i in range(sides):
            a0 = i
            a1 = (i + 1) % sides
            faces.append((ca, a0, a1))

    if end_caps and seg.rb > 0.0:
        cb = len(vertices)
        vertices.append(seg.b)
        for i in range(sides):
            b0 = i + sides
            b1 = ((i + 1) % sides) + sides
            faces.append((cb, b1, b0))

    return vertices, faces

batch_frusta(segments, *, sides=16, end_caps=False)

Batch multiple frusta into a single mesh.

Returns a concatenated list of vertices and faces with the proper index offsets.

Source code in swcviz/geometry.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def batch_frusta(
    segments: Iterable[Segment], *, sides: int = 16, end_caps: bool = False
) -> Tuple[List[Point3], List[Face]]:
    """Concatenate per-segment frustum meshes into one mesh.

    Each segment is meshed with `frustum_mesh`. Its face indices are shifted by
    the number of vertices already emitted for earlier segments, so they index
    into the combined vertex list.
    """
    vertices: List[Point3] = []
    faces: List[Face] = []

    for seg in segments:
        seg_vertices, seg_faces = frustum_mesh(seg, sides=sides, end_caps=end_caps)
        base = len(vertices)  # index of this segment's first vertex
        faces.extend((p + base, q + base, r + base) for p, q, r in seg_faces)
        vertices.extend(seg_vertices)

    return vertices, faces

plot_centroid(gm, *, marker_size=2.0, line_width=2.0, show_nodes=True)

Plot centroid skeleton from a GeneralModel.

Edges are drawn as line segments in 3D using Scatter3d.

Source code in swcviz/viz.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def plot_centroid(gm, *, marker_size: float = 2.0, line_width: float = 2.0, show_nodes: bool = True) -> go.Figure:
    """Render the centroid skeleton of a GeneralModel as a Plotly figure.

    All edges go into a single `Scatter3d` line trace, with a `None` gap after
    every edge so that Plotly does not join consecutive edges. When
    `show_nodes` is True, node positions are added as a second marker trace.

    Parameters
    ----------
    gm
        Graph exposing `edges` and `nodes[n]` mappings with `x`, `y`, `z`.
    marker_size: float
        Marker size for nodes.
    line_width: float
        Width of the edge lines.
    show_nodes: bool
        Whether to add the node marker trace.
    """
    coords = {axis: [] for axis in ("x", "y", "z")}
    for u, v in gm.edges:
        start, end = gm.nodes[u], gm.nodes[v]
        for axis, values in coords.items():
            values.extend([start[axis], end[axis], None])

    traces = [
        go.Scatter3d(
            x=coords["x"],
            y=coords["y"],
            z=coords["z"],
            mode="lines",
            line=dict(width=line_width, color="#1f77b4"),
            name="edges",
        )
    ]

    if show_nodes:
        node_ids = list(gm.nodes)
        traces.append(
            go.Scatter3d(
                x=[gm.nodes[n]["x"] for n in node_ids],
                y=[gm.nodes[n]["y"] for n in node_ids],
                z=[gm.nodes[n]["z"] for n in node_ids],
                mode="markers",
                marker=dict(size=marker_size, color="#ff7f0e"),
                name="nodes",
            )
        )

    fig = go.Figure(data=traces)
    apply_layout(fig, title="Centroid Skeleton")
    return fig

plot_frusta(frusta, *, color='lightblue', opacity=0.8, flatshading=True, radius_scale=1.0)

Plot a FrustaSet as a Mesh3d figure.

Parameters

frusta: FrustaSet — batched frusta mesh to render. color: str — mesh color. opacity: float — mesh opacity. flatshading: bool — whether to enable flat shading. radius_scale: float — uniform scale applied to all segment radii before meshing (1.0 = no change).

Source code in swcviz/viz.py
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def plot_frusta(
    frusta: FrustaSet,
    *,
    color: str = "lightblue",
    opacity: float = 0.8,
    flatshading: bool = True,
    radius_scale: float = 1.0,
) -> go.Figure:
    """Render a FrustaSet as a single Plotly Mesh3d trace.

    Parameters
    ----------
    frusta: FrustaSet
        Batched frusta mesh to render.
    color: str
        Mesh color.
    opacity: float
        Mesh opacity.
    flatshading: bool
        Whether to enable flat shading.
    radius_scale: float
        Uniform factor applied to every segment radius before meshing. At
        exactly 1.0 the mesh is used as-is.
    """
    mesh_source = frusta.scaled(radius_scale) if radius_scale != 1.0 else frusta
    mx, my, mz, mi, mj, mk = mesh_source.to_mesh3d_arrays()

    fig = go.Figure(
        data=[
            go.Mesh3d(
                x=mx,
                y=my,
                z=mz,
                i=mi,
                j=mj,
                k=mk,
                color=color,
                opacity=opacity,
                flatshading=flatshading,
            )
        ]
    )
    apply_layout(fig, title="Frusta Mesh")
    return fig

plot_frusta_with_centroid(gm, frusta, *, color='lightblue', opacity=0.8, flatshading=True, radius_scale=1.0, centroid_color='#1f77b4', centroid_line_width=2.0, show_nodes=False, node_size=2.0)

Overlay frusta mesh with centroid skeleton from a GeneralModel.

Parameters mirror those of plot_centroid and plot_frusta, including radius_scale.

Source code in swcviz/viz.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def plot_frusta_with_centroid(
    gm,
    frusta: FrustaSet,
    *,
    color: str = "lightblue",
    opacity: float = 0.8,
    flatshading: bool = True,
    radius_scale: float = 1.0,
    centroid_color: str = "#1f77b4",
    centroid_line_width: float = 2.0,
    show_nodes: bool = False,
    node_size: float = 2.0,
) -> go.Figure:
    """Draw a frusta mesh together with the centroid skeleton of a `GeneralModel`.

    Parameters
    ----------
    gm
        Graph whose `edges` become centroid lines and whose `nodes[n]` supply
        `x`, `y`, `z` coordinates.
    frusta: FrustaSet
        Mesh to overlay.
    color, opacity, flatshading
        Mesh appearance, as in `plot_frusta`.
    radius_scale: float
        Uniform radius scale applied to the mesh (1.0 = no change).
    centroid_color, centroid_line_width
        Style of the centroid line trace.
    show_nodes: bool
        If True, also draw node positions as markers.
    node_size: float
        Marker size for nodes.

    Returns
    -------
    go.Figure
        Traces in order: centroid lines, optional node markers, frusta mesh.
    """
    # Centroid polyline: one (start, end, None) triple per edge.
    line_x, line_y, line_z = [], [], []
    for u, v in gm.edges:
        start, end = gm.nodes[u], gm.nodes[v]
        line_x += [start["x"], end["x"], None]
        line_y += [start["y"], end["y"], None]
        line_z += [start["z"], end["z"], None]

    traces = [
        go.Scatter3d(
            x=line_x,
            y=line_y,
            z=line_z,
            mode="lines",
            line=dict(width=centroid_line_width, color=centroid_color),
            name="centroid",
        )
    ]

    if show_nodes:
        node_ids = list(gm.nodes)
        traces.append(
            go.Scatter3d(
                x=[gm.nodes[n]["x"] for n in node_ids],
                y=[gm.nodes[n]["y"] for n in node_ids],
                z=[gm.nodes[n]["z"] for n in node_ids],
                mode="markers",
                marker=dict(size=node_size, color="#ff7f0e"),
                name="nodes",
            )
        )

    shown = frusta.scaled(radius_scale) if radius_scale != 1.0 else frusta
    mx, my, mz, mi, mj, mk = shown.to_mesh3d_arrays()
    traces.append(
        go.Mesh3d(
            x=mx,
            y=my,
            z=mz,
            i=mi,
            j=mj,
            k=mk,
            color=color,
            opacity=opacity,
            flatshading=flatshading,
            name="frusta",
        )
    )

    fig = go.Figure(data=traces)
    apply_layout(fig, title="Centroid + Frusta")
    return fig

plot_frusta_slider(frusta, *, color='lightblue', opacity=0.8, flatshading=True, min_scale=0.0, max_scale=1.0, steps=21)

Interactive slider controlling a uniform radius_scale (range 0–1 by default).

Precomputes frames at evenly spaced scales between min_scale and max_scale.

Source code in swcviz/viz.py
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def _unique_scale_labels(scales: list[float]) -> list[str]:
    """Return display labels for `scales` that are pairwise distinct.

    Two decimals are used whenever they already distinguish every scale, as in
    the default 21-step 0..1 range. Otherwise precision is raised one digit at
    a time, up to six decimals. If scales still collide (e.g. when
    ``min_scale == max_scale``), the step index is appended to each label.
    """
    for digits in range(2, 7):
        labels = [f"{s:.{digits}f}" for s in scales]
        if len(set(labels)) == len(labels):
            return labels
    return [f"{s:.2f} #{idx}" for idx, s in enumerate(scales)]


def plot_frusta_slider(
    frusta: FrustaSet,
    *,
    color: str = "lightblue",
    opacity: float = 0.8,
    flatshading: bool = True,
    min_scale: float = 0.0,
    max_scale: float = 1.0,
    steps: int = 21,
) -> go.Figure:
    """Interactive slider (0..1 default) controlling uniform `radius_scale`.

    Precomputes frames at evenly spaced scales between `min_scale` and
    `max_scale`, both endpoints included.

    Parameters
    ----------
    frusta: FrustaSet
        Unscaled frusta mesh; its triangle topology is shared by every frame.
    color, opacity, flatshading
        Mesh appearance.
    min_scale, max_scale: float
        Range of radius scales offered by the slider.
    steps: int
        Number of precomputed frames; values below 2 are raised to 2.

    Returns
    -------
    go.Figure
        Figure with one Mesh3d trace, one frame per scale, a slider and
        Play/Pause buttons. The initial view is the scale closest to 1.0 when
        1.0 lies in range, and the first scale otherwise.
    """
    steps = max(2, int(steps))
    span = max_scale - min_scale
    scales = [min_scale + (span * k / (steps - 1)) for k in range(steps)]

    # Plotly resolves slider steps to frames by name, so names must be unique.
    # Plain 2-decimal formatting collides for fine steps (e.g. 0.005 and 0.01
    # both render as "0.01").
    labels = _unique_scale_labels(scales)
    frame_names = [f"scale={label}" for label in labels]

    # Scaling radii moves vertices but leaves the triangle topology unchanged,
    # so i/j/k come from the unscaled mesh and are reused by every frame.
    _, _, _, bi, bj, bk = frusta.to_mesh3d_arrays()

    # Vertex coordinates for each scale; scale 1.0 reuses `frusta` as-is.
    coords = []
    for s in scales:
        fr_s = frusta if s == 1.0 else frusta.scaled(s)
        xs, ys, zs, _, _, _ = fr_s.to_mesh3d_arrays()
        coords.append((xs, ys, zs))

    # Initial view: prefer scale = 1.0 if within range; otherwise first scale
    if min_scale <= 1.0 <= max_scale:
        init_idx = min(range(steps), key=lambda idx: abs(scales[idx] - 1.0))
    else:
        init_idx = 0
    x0, y0, z0 = coords[init_idx]

    mesh = go.Mesh3d(
        x=x0,
        y=y0,
        z=z0,
        i=bi,
        j=bj,
        k=bk,
        color=color,
        opacity=opacity,
        flatshading=flatshading,
        name="frusta",
    )

    frames = [
        go.Frame(
            name=name,
            data=[go.Mesh3d(x=xs, y=ys, z=zs, i=bi, j=bj, k=bk, color=color, opacity=opacity, flatshading=flatshading)],
        )
        for name, (xs, ys, zs) in zip(frame_names, coords)
    ]

    # Slider and play controls
    slider_steps = [
        {
            "label": label,
            "method": "animate",
            "args": [[name], {"mode": "immediate", "frame": {"duration": 0}, "transition": {"duration": 0}}],
        }
        for label, name in zip(labels, frame_names)
    ]

    sliders = [
        {
            "active": init_idx,
            "currentvalue": {"prefix": "radius_scale: ", "visible": True},
            "steps": slider_steps,
        }
    ]

    updatemenus = [
        {
            "type": "buttons",
            "direction": "left",
            "pad": {"r": 10, "t": 60},
            "showactive": False,
            "x": 0.0,
            "y": 0,
            "buttons": [
                {"label": "▶ Play", "method": "animate", "args": [None, {"fromcurrent": True, "frame": {"duration": 0}, "transition": {"duration": 0}}]},
                {"label": "❚❚ Pause", "method": "animate", "args": [[None], {"mode": "immediate", "frame": {"duration": 0}, "transition": {"duration": 0}}]},
            ],
        }
    ]

    fig = go.Figure(data=[mesh], frames=frames)
    apply_layout(fig, title="Frusta Mesh — radius_scale slider")
    fig.update_layout(sliders=sliders, updatemenus=updatemenus)
    return fig

plot_model(*, gm=None, frusta=None, show_frusta=True, show_centroid=True, sides=16, end_caps=False, color='lightblue', opacity=0.8, flatshading=True, radius_scale=1.0, slider=False, min_scale=0.0, max_scale=1.0, steps=21, centroid_color='#1f77b4', centroid_line_width=2.0, show_nodes=False, node_size=2.0, point_set=None, point_size=1.0, point_color='#d62728')

Master visualization combining centroid, frusta, slider, and overlay points.

  • If frusta is not provided and gm is, a FrustaSet is built from gm.
  • If slider=True and show_frusta=True, a Plotly slider controls radius_scale.
  • point_set, if given, overlays arbitrary xyz positions as small low-resolution spheres.
Source code in swcviz/viz.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def plot_model(
    *,
    gm=None,
    frusta: FrustaSet | None = None,
    show_frusta: bool = True,
    show_centroid: bool = True,
    # Frusta build options (used if frusta is None and gm provided)
    sides: int = 16,
    end_caps: bool = False,
    # Frusta appearance
    color: str = "lightblue",
    opacity: float = 0.8,
    flatshading: bool = True,
    # Scaling and interactivity
    radius_scale: float = 1.0,
    slider: bool = False,
    min_scale: float = 0.0,
    max_scale: float = 1.0,
    steps: int = 21,
    # Centroid appearance
    centroid_color: str = "#1f77b4",
    centroid_line_width: float = 2.0,
    show_nodes: bool = False,
    node_size: float = 2.0,
    # Extra points overlay (as low-res spheres)
    point_set: PointSet | None = None,
    point_size: float = 1.0,
    point_color: str = "#d62728",
) -> go.Figure:
    """Master visualization combining centroid, frusta, slider, and overlay points.

    - If `frusta` is not provided and `gm` is, a `FrustaSet` is built from `gm`.
    - If `slider=True` and `show_frusta=True`, a Plotly slider controls `radius_scale`.
      The slider starts at the step closest to `radius_scale` when that value lies
      in ``[min_scale, max_scale]``, and at the first step otherwise.
    - `point_set` overlays arbitrary xyz positions as small sphere meshes.

    Parameters
    ----------
    gm:
        Graph-like model whose ``nodes`` carry ``x``/``y``/``z`` attributes and
        whose ``edges`` are drawn as the centroid lines. Required when
        `show_frusta` is True and `frusta` is None.
    frusta:
        Prebuilt frusta; used instead of building them from `gm`.
    show_frusta, show_centroid:
        Toggle the frusta mesh and the centroid lines (centroid requires `gm`).
    sides, end_caps:
        Passed to ``FrustaSet.from_general_model`` when frusta are built here.
    color, opacity, flatshading:
        Appearance of the frusta mesh.
    radius_scale:
        Static radius scale; with `slider=True` it selects the initial slider step.
    slider, min_scale, max_scale, steps:
        Enable a slider over `steps` evenly spaced scales from `min_scale` to
        `max_scale` (at least 2 steps are always used).
    centroid_color, centroid_line_width:
        Appearance of the centroid lines.
    show_nodes, node_size:
        Draw node markers; only drawn together with the centroid.
    point_set, point_size, point_color:
        Optional overlay; `point_size` is passed to ``PointSet.scaled`` unless
        it equals 1.0.

    Returns
    -------
    go.Figure
        Figure with the global layout applied. When the slider is active it also
        carries one animation frame per scale, a slider and Play/Pause buttons.

    Raises
    ------
    ValueError
        If `show_frusta` is True and neither `frusta` nor `gm` is provided.
    """

    traces: list[go.BaseTraceType] = []
    frames: list[go.Frame] | None = None
    sliders: list[dict] | None = None
    updatemenus: list[dict] | None = None

    # Build frusta if needed
    base_fr = frusta
    if show_frusta and base_fr is None:
        if gm is None:
            raise ValueError("plot_model: provide either `frusta` or a `gm` to build from")
        base_fr = FrustaSet.from_general_model(gm, sides=sides, end_caps=end_caps)

    # Centroid traces: one segment per edge; the trailing None breaks the
    # polyline so segments are not joined to each other.
    if show_centroid and gm is not None:
        xs, ys, zs = [], [], []
        for u, v in gm.edges:
            xs.extend([gm.nodes[u]["x"], gm.nodes[v]["x"], None])
            ys.extend([gm.nodes[u]["y"], gm.nodes[v]["y"], None])
            zs.extend([gm.nodes[u]["z"], gm.nodes[v]["z"], None])
        traces.append(
            go.Scatter3d(
                x=xs,
                y=ys,
                z=zs,
                mode="lines",
                line=dict(width=centroid_line_width, color=centroid_color),
                name="centroid",
            )
        )

        if show_nodes:
            xn = [gm.nodes[n]["x"] for n in gm.nodes]
            yn = [gm.nodes[n]["y"] for n in gm.nodes]
            zn = [gm.nodes[n]["z"] for n in gm.nodes]
            traces.append(
                go.Scatter3d(
                    x=xn,
                    y=yn,
                    z=zn,
                    mode="markers",
                    marker=dict(size=node_size, color="#ff7f0e"),
                    name="nodes",
                )
            )

    # Overlay points as small spheres mesh
    if point_set is not None:
        ps = point_set if point_size == 1.0 else point_set.scaled(point_size)
        px, py, pz, pi, pj, pk = ps.to_mesh3d_arrays()
        # Appended after the centroid traces; the frusta mesh (if any) is
        # inserted at index 0 below.
        traces.append(
            go.Mesh3d(
                x=px,
                y=py,
                z=pz,
                i=pi,
                j=pj,
                k=pk,
                color=point_color,
                opacity=1.0,
                flatshading=True,
                name="points",
            )
        )

    # Frusta (optionally with slider)
    if show_frusta and base_fr is not None:
        if slider:
            n_steps = max(2, int(steps))
            span = max_scale - min_scale
            scales = [min_scale + (span * k / (n_steps - 1)) for k in range(n_steps)]

            # Start at the step nearest the requested radius_scale when it is
            # reachable by the slider (defaults to ~1.0).
            if min_scale <= radius_scale <= max_scale:
                init_idx = min(range(n_steps), key=lambda idx: abs(scales[idx] - radius_scale))
            else:
                init_idx = 0

            # Topology (i/j/k) is shared by every scale; only vertex positions
            # change, so compute each scale's coordinates exactly once.
            bx, by, bz, bi, bj, bk = base_fr.to_mesh3d_arrays()
            scaled_xyz = []
            for s in scales:
                if s == 1.0:
                    scaled_xyz.append((bx, by, bz))
                else:
                    sx, sy, sz, _, _, _ = base_fr.scaled(s).to_mesh3d_arrays()
                    scaled_xyz.append((sx, sy, sz))

            x0, y0, z0 = scaled_xyz[init_idx]
            mesh = go.Mesh3d(
                x=x0,
                y=y0,
                z=z0,
                i=bi,
                j=bj,
                k=bk,
                color=color,
                opacity=opacity,
                flatshading=flatshading,
                name="frusta",
            )
            # Frames below carry a single trace and no explicit `traces`
            # indices, so they update trace 0: the mesh must be first.
            traces.insert(0, mesh)

            frames = [
                go.Frame(
                    name=f"scale={s:.2f}",
                    data=[go.Mesh3d(x=sx, y=sy, z=sz, i=bi, j=bj, k=bk, color=color, opacity=opacity, flatshading=flatshading)],
                )
                for s, (sx, sy, sz) in zip(scales, scaled_xyz)
            ]

            slider_steps = [
                {
                    "label": f"{s:.2f}",
                    "method": "animate",
                    "args": [[f"scale={s:.2f}"], {"mode": "immediate", "frame": {"duration": 0}, "transition": {"duration": 0}}],
                }
                for s in scales
            ]

            sliders = [
                {
                    "active": init_idx,
                    "currentvalue": {"prefix": "radius_scale: ", "visible": True},
                    "steps": slider_steps,
                }
            ]

            updatemenus = [
                {
                    "type": "buttons",
                    "direction": "left",
                    "pad": {"r": 10, "t": 60},
                    "showactive": False,
                    "x": 0.0,
                    "y": 0,
                    "buttons": [
                        {"label": "▶ Play", "method": "animate", "args": [None, {"fromcurrent": True, "frame": {"duration": 0}, "transition": {"duration": 0}}]},
                        {"label": "❚❚ Pause", "method": "animate", "args": [[None], {"mode": "immediate", "frame": {"duration": 0}, "transition": {"duration": 0}}]},
                    ],
                }
            ]
        else:
            # Static radius scale
            fr = base_fr if radius_scale == 1.0 else base_fr.scaled(radius_scale)
            x, y, z, i, j, k = fr.to_mesh3d_arrays()
            mesh = go.Mesh3d(
                x=x,
                y=y,
                z=z,
                i=i,
                j=j,
                k=k,
                color=color,
                opacity=opacity,
                flatshading=flatshading,
                name="frusta",
            )
            traces.insert(0, mesh)  # keep mesh on bottom for visibility

    fig = go.Figure(data=traces, frames=frames)
    apply_layout(fig, title="Model")
    if sliders is not None:
        fig.update_layout(sliders=sliders, updatemenus=updatemenus)
    return fig

get_config()

Return the current visualization configuration (live object).

Source code in swcviz/config.py
38
39
40
def get_config() -> VizConfig:
    """Return the shared, module-wide visualization configuration.

    The returned object is the live instance itself rather than a copy, so
    changing its attributes changes what ``apply_layout`` reads afterwards.
    """
    current = _config
    return current

set_config(**kwargs)

Update global visualization configuration.

Example

set_config(width=800, height=600, scene_aspectmode="cube")

Source code in swcviz/config.py
43
44
45
46
47
48
49
50
51
52
def set_config(**kwargs: Any) -> None:
    """Update global visualization configuration.

    Every key is validated before any attribute is modified, so a call that
    contains an unknown key leaves the configuration completely untouched
    (previously, keys preceding the bad one were already applied).

    Example:
        set_config(width=800, height=600, scene_aspectmode="cube")

    Raises:
        AttributeError: If any key is not an existing attribute of the
            configuration object; the message names the first such key.
    """
    unknown = [key for key in kwargs if not hasattr(_config, key)]
    if unknown:
        raise AttributeError(f"Unknown viz config key: {unknown[0]}")
    for key, value in kwargs.items():
        setattr(_config, key, value)

apply_layout(fig, *, title=None)

Apply global layout defaults to a Plotly figure in-place.

Source code in swcviz/config.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def apply_layout(fig, *, title: str | None = None) -> None:
    """Apply global layout defaults to a Plotly figure in-place."""
    # Determine aspect mode (force equal axes if requested)
    aspectmode = "data" if _config.force_equal_axes else _config.scene_aspectmode

    fig.update_layout(
        width=_config.width,
        height=_config.height,
        template=_config.template,
        margin=_config.margin,
        showlegend=_config.showlegend,
        scene_aspectmode=aspectmode,
    )
    # Only used when manual aspect is requested
    if aspectmode == "manual":
        fig.update_layout(scene=dict(aspectratio=_config.scene_aspectratio))
    if title is not None:
        fig.update_layout(title=title)