Visualization API

vis

Visualization utilities for FRUST data and molecular structures.

The package keeps the historical frust.vis import surface while grouping implementation code by visualization type.

plot_energy_profile

plot_energy_profile(states, ylabel: str = 'ΔG (kcal/mol)', n_points: int = 500, figsize=(8, 3.5), annotate_energies: bool = True, decimals: int = 1, int_prefix: str = 'int', label_offset_up: int = 8, label_offset_down: int = 12, hide_y_ticks: bool = True, hide_x_ticks: bool = True, hide_spines: bool = True, grid: bool = False, ax=None, dummy_substr: str = 'dummy', dummy_alpha: float = 0.5, side_token: str = 'side-rxn', show_main_to_product: bool = True, main_to_product_alpha: float = 1, main_to_product_linestyle: str = ':', main_to_product_lw: float = 3.0, main_to_product_bow: float = 0.5, main_to_product_drop_frac: float = 0.65, main_to_product_drop_points: int | None = None, main_to_product_flat_points: int | None = None, product_x_offset: float = 0.18, overlay: str = 'auto', overlay_annotate: str = 'energy', overlay_alpha: float = 0.35, overlay_lw_scale: float = 1.0, marker: str = 'o', overlay_markers=None, show_legend: bool = True, profile_label: str | None = None, overlay_colors=None, same_energy_tol: float = 0.001, same_energy_mode: str = 'hide', same_energy_tag: str = '≡', show_state_labels: bool | None = None, state_label_rotation: float = 0.0, font_size: float | None = None, state_label_fontsize: float | None = None, energy_fontsize: float | None = None, axis_label_fontsize: float | None = None, tick_label_fontsize: float | None = None, legend_fontsize: float | None = None, same_energy_tag_fontsize: float | None = None, state_label_pad: float = 6.0)

Plot one or more reaction energy profiles.
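The `states` argument takes either one profile or several named profiles. The helper below is a minimal sketch that mirrors the shape check in the source of `plot_energy_profile`. Input counts as multi-profile when it is a dict, or a non-empty list or tuple whose first item is a `(name, states)` pair with a string name. `is_multi_profile` and the sample energies are illustrative assumptions, not part of the frust API.

```python
def is_multi_profile(states) -> bool:
    """Return True when `states` looks like several named profiles.

    Hypothetical helper mirroring the shape check in plot_energy_profile:
    a dict, or a non-empty list/tuple whose first item is a
    (profile_name: str, profile_states: list | tuple) pair.
    """
    if isinstance(states, dict):
        return True
    if isinstance(states, (list, tuple)) and states:
        first = states[0]
        return (
            isinstance(first, (list, tuple))
            and len(first) == 2
            and isinstance(first[0], str)
            and isinstance(first[1], (list, tuple))
        )
    return False


# Single profile: (label, energy) entries; the energies are made-up numbers.
single = [("R", 0.0), ("TS1", 18.2), ("int1", -3.5), ("TS2", 12.0), ("P", -10.4)]

# Several profiles: the first is the reference, the rest are overlays.
multi = {
    "catalyst A": single,
    "catalyst B": [("R", 0.0), ("TS1", 20.1), ("int1", -1.2), ("TS2", 14.3), ("P", -10.4)],
}

print(is_multi_profile(single))  # False
print(is_multi_profile(multi))   # True
```

With multi-profile input, the first profile is plotted as the reference and the others as overlays. `overlay="off"` bypasses this check, and `overlay="on"` raises `ValueError` when the check fails.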

Parameters:

Name Type Description Default
states

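Single profile as a sequence of (label, energy[, placement]) entries, or multiple profiles as a dict mapping profile names to states, or a list of (profile_name, states) pairs; the first profile is the reference. A string marker such as "side-rxn@int2@0.6#Side product" starts a side-reaction segment and cannot be the first entry. In that example, int2 names the main-path state the side branch is anchored to, 0.6 sets where along the connector it starts to rise, and "Side product" is its legend label.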

required
ylabel str

Y-axis label.

'ΔG (kcal/mol)'
n_points int

Number of interpolation points used for smooth profile curves.

500
figsize

Figure size used when ax is not provided.

(8, 3.5)
annotate_energies bool

Whether to annotate energies for the reference profile. In single-profile mode, each annotation also shows the state label.

True
decimals int

Number of decimal places shown in energy labels.

1
int_prefix str

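Label prefix that marks intermediates. Without an explicit placement, states whose label starts with this prefix (case-insensitive) are annotated below the point; other interior states are annotated above, the first state to the left, and the last state to the right.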

'int'
label_offset_up int

Offset, in points, of annotations placed above a point.

8
label_offset_down int

Offset, in points, of annotations placed below a point.

12
hide_y_ticks bool

Whether to hide the y-axis ticks.

True
hide_x_ticks bool

Whether to hide the x-axis ticks.

True
hide_spines bool

Whether to hide the axes spines.

True
grid bool

Whether to show the Matplotlib grid.

False
ax

Existing Matplotlib axes. If omitted, a new figure and axes are created.

None
dummy_substr str

Substring used to detect dummy states that should render with reduced annotation alpha.

'dummy'
dummy_alpha float

Alpha multiplier for dummy-state annotations.

0.5
side_token str

String token that starts side-reaction parsing.

'side-rxn'
show_main_to_product bool

Whether to draw the dotted main-path connection to the product after a side-reaction branch.

True
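main_to_product_alpha float

Opacity of the main-to-product connector, multiplied by the profile's alpha.

1
main_to_product_linestyle str

Matplotlib line style of the main-to-product connector.

':'
main_to_product_lw float

Line width of the main-to-product connector.

3.0
main_to_product_bow float

Additional main-to-product connector setting; not covered by the docstring.

0.5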
main_to_product_drop_frac float

Fraction of the connector x-distance kept flat before dropping to the product energy.

0.65
main_to_product_drop_points int | None

Number of points in the connector's sloped (drop) segment. If None, max(80, int(0.35 * n_points)) is used.

None
main_to_product_flat_points int | None

Number of points in the connector's flat segment. If None, max(20, int(0.15 * n_points)) is used.

None
product_x_offset float

Horizontal spacing between multiple product-like states.

0.18
overlay str

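Overlay mode: "auto" detects multiple profiles from the shape of states, "off" always plots states as a single profile, and "on" requires multi-profile input and raises ValueError otherwise.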

'auto'
overlay_annotate str

Annotation mode for overlay profiles: "none", "energy", or "full". "energy" and "full" enable energy labels on overlay profiles; "none" disables them.

'energy'
overlay_alpha float

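Opacity of overlay profiles; the reference profile is drawn fully opaque.

0.35
overlay_lw_scale float

Line-width multiplier for overlay profiles, applied to the base width of 1.5.

1.0
marker str

Marker style for the reference profile, also used as the fallback for overlays.

'o'
overlay_markers

Optional dict mapping profile names to marker styles for overlay profiles. Profiles missing from the dict use marker.

None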
show_legend bool

Whether to draw a legend.

True
profile_label str | None

Legend label for a single profile.

None
overlay_colors

Optional overlay color mapping or sequence. A two-item tuple sets (main_color, side_color).

None
same_energy_tol float

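Absolute tolerance within which an overlay energy counts as equal to the reference energy of the same state label.

0.001
same_energy_mode str

"hide" suppresses overlay energy labels that match the reference; "tag" also marks the matching reference energies with same_energy_tag.

'hide'
same_energy_tag str

Text drawn next to matching reference energies when same_energy_mode is "tag" and annotate_energies is True.

'≡'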
show_state_labels bool | None

Whether to show state names as x-axis labels. If None, they are shown only for multi-profile input.

None
state_label_rotation float

Rotation of the x-axis state labels.

0.0
state_label_pad float

Padding of the x-axis state labels.

6.0
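font_size float | None

Base font size for all text. If None, 12 is used.

None
state_label_fontsize float | None

Font size of the x-axis state labels. Overrides font_size.

None
energy_fontsize float | None

Font size of the energy annotations. Overrides font_size.

None
axis_label_fontsize float | None

Font size of the axis labels. Overrides font_size.

None
tick_label_fontsize float | None

Font size of the tick labels. Overrides font_size.

None
legend_fontsize float | None

Font size of the legend text. Overrides font_size.

None
same_energy_tag_fontsize float | None

Font size of the same-energy tag. Defaults to energy_fontsize.

None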
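When an entry carries no explicit placement, the source chooses where to put its annotation from the state's position and label: the first state goes to the left, the last to the right, labels starting with `int_prefix` go below, and everything else goes above. The sketch below reproduces that rule and the offsets used in the single-profile branch. `default_placement` and `annotation_offset` are hypothetical helpers, and the `lower().strip()` normalization is an assumption standing in for the source's internal `_norm_label`, whose exact behavior is not shown here.

```python
def default_placement(index: int, n_states: int, label: str,
                      int_prefix: str = "int") -> dict[str, int]:
    """Hypothetical helper: placement counts for an entry without explicit placement.

    `index` is 0-based. lower().strip() is an assumed stand-in for _norm_label.
    """
    key = label.lower().strip()
    if index == 0:
        side = "left"
    elif index == n_states - 1:
        side = "right"
    elif key.startswith(int_prefix.lower()):
        side = "bottom"
    else:
        side = "top"
    counts = {"top": 0, "bottom": 0, "left": 0, "right": 0}
    counts[side] = 1
    return counts


def annotation_offset(counts: dict[str, int], label_offset_up: int = 8,
                      label_offset_down: int = 12) -> tuple[int, int, str, str]:
    """(dx, dy, ha, va) in offset points, mirroring the single-profile branch."""
    dx, dy, ha, va = 0, 0, "center", "center"
    # Horizontal: 12 points per left/right count.
    if counts["left"]:
        dx, ha = -12 * counts["left"], "right"
    elif counts["right"]:
        dx, ha = 12 * counts["right"], "left"
    # Vertical: label_offset_up / label_offset_down per top/bottom count.
    if counts["top"]:
        dy, va = abs(label_offset_up) * counts["top"], "bottom"
    elif counts["bottom"]:
        dy, va = -abs(label_offset_down) * counts["bottom"], "top"
    return dx, dy, ha, va


labels = ["R", "TS1", "int1", "TS2", "P"]
for i, lab in enumerate(labels):
    print(lab, annotation_offset(default_placement(i, len(labels), lab)))
```

Here R lands to the left, int1 below, and P to the right. Explicit placements are parsed by the source's `_parse_placement`, whose string format is not documented on this page; a count above 1 pushes the label further out and adds an arrow.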

Returns:

Type Description
tuple

(fig, ax) for the Matplotlib figure and axes.
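After a side-reaction branch, the main path is joined to its product by a connector governed by the `main_to_product_*` parameters. It stays flat at the last main-path energy for `main_to_product_drop_frac` of the horizontal distance, then eases to the product energy with a cubic smoothstep, `s = t*t*(3 - 2*t)`. The side-branch connector in the source uses the same easing. The pure-Python sketch below reproduces this geometry without plotting. `connector_points` and the small `linspace` stand-in for `numpy.linspace` are hypothetical helpers.

```python
def linspace(start: float, stop: float, num: int, endpoint: bool = True) -> list[float]:
    """Evenly spaced values; a small pure-Python stand-in for numpy.linspace."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / ((num - 1) if endpoint else num)
    values = [start + i * step for i in range(num)]
    if endpoint:
        values[-1] = float(stop)  # pin the last value exactly
    return values


def smoothstep(t: float) -> float:
    """Cubic ease 3t^2 - 2t^3, with t clamped to [0, 1] first."""
    t = min(max(t, 0.0), 1.0)
    return t * t * (3.0 - 2.0 * t)


def connector_points(x0, y0, x1, y1, drop_frac=0.65, n_points=500,
                     flat_points=None, drop_points=None):
    """Hypothetical helper: (xs, ys) of a flat-then-eased connector from (x0, y0) to (x1, y1)."""
    frac = min(max(float(drop_frac), 0.0), 1.0)
    x_drop = x0 + frac * (x1 - x0)

    # Default segment sizes follow the formulas in the source.
    n_flat = flat_points if flat_points is not None else max(20, int(n_points * 0.15))
    n_drop = drop_points if drop_points is not None else max(80, int(n_points * 0.35))

    # Flat part at the starting energy; x_drop itself belongs to the next segment.
    xs_flat = linspace(x0, x_drop, max(2, n_flat), endpoint=False)
    ys_flat = [float(y0)] * len(xs_flat)

    # Eased part from the starting energy to the product energy.
    xs_drop = linspace(x_drop, x1, max(2, n_drop))
    span = x1 - x_drop
    if span == 0:
        ys_drop = [float(y1)] * len(xs_drop)
    else:
        ys_drop = [y0 + (y1 - y0) * smoothstep((x - x_drop) / span) for x in xs_drop]

    return xs_flat + xs_drop, ys_flat + ys_drop


xs, ys = connector_points(0.0, 5.0, 4.0, -2.0, drop_frac=0.5, flat_points=10, drop_points=5)
print(len(xs), ys[0], ys[-1])  # 15 5.0 -2.0
```

Because the smoothstep has zero slope at both ends, the connector leaves the flat segment and meets the product point without a visible kink.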

Source code in frust/vis/energy_profile/api.py
def plot_energy_profile(
    states,
    ylabel: str = "ΔG (kcal/mol)",
    n_points: int = 500,
    figsize=(8, 3.5),
    annotate_energies: bool = True,
    decimals: int = 1,
    int_prefix: str = "int",
    label_offset_up: int = 8,
    label_offset_down: int = 12,
    hide_y_ticks: bool = True,
    hide_x_ticks: bool = True,
    hide_spines: bool = True,
    grid: bool = False,
    ax=None,
    dummy_substr: str = "dummy",
    dummy_alpha: float = 0.5,
    side_token: str = "side-rxn",
    show_main_to_product: bool = True,
    main_to_product_alpha: float = 1,
    main_to_product_linestyle: str = ":",
    main_to_product_lw: float = 3.0,
    main_to_product_bow: float = 0.5,
    main_to_product_drop_frac: float = 0.65,
    main_to_product_drop_points: int | None = None,
    main_to_product_flat_points: int | None = None,
    product_x_offset: float = 0.18,
    # --- multi-molecule overlays ---
    overlay: str = "auto",  # "auto" | "off" | "on"
    overlay_annotate: str = "energy",  # "none" | "energy" | "full"
    overlay_alpha: float = 0.35,
    overlay_lw_scale: float = 1.0,
    marker: str = "o",
    overlay_markers=None,
    show_legend: bool = True,
    profile_label: str | None = None,
    overlay_colors=None,
    same_energy_tol: float = 1e-3,
    same_energy_mode: str = "hide",  # "hide" | "tag"
    same_energy_tag: str = "≡",
    # --- NEW: bottom state labels (recommended for overlays) ---
    show_state_labels: bool | None = None,
    state_label_rotation: float = 0.0,
    font_size: float | None = None,
    state_label_fontsize: float | None = None,
    energy_fontsize: float | None = None,
    axis_label_fontsize: float | None = None,
    tick_label_fontsize: float | None = None,
    legend_fontsize: float | None = None,
    same_energy_tag_fontsize: float | None = None,
    state_label_pad: float = 6.0,
):
    """Plot one or more reaction energy profiles.

    Parameters
    ----------
    states
        Single profile as a sequence of ``(label, energy[, placement])`` entries,
        or multiple profiles as a mapping/list of ``(profile_name, states)``.
        A string marker such as ``"side-rxn@int2@0.6#Side product"`` starts a
        side-reaction segment.
    ylabel
        Y-axis label.
    n_points
        Number of interpolation points used for smooth profile curves.
    figsize
        Figure size used when `ax` is not provided.
    annotate_energies
        Whether to annotate energies for the reference profile.
    decimals
        Number of decimal places shown in energy labels.
    int_prefix
        Label prefix treated as an intermediate for default label placement.
    label_offset_up, label_offset_down
        Point offsets used for top and bottom annotations.
    hide_y_ticks, hide_x_ticks, hide_spines
        Axis cleanup options.
    grid
        Whether to show the Matplotlib grid.
    ax
        Existing Matplotlib axes. If omitted, a new figure and axes are created.
    dummy_substr
        Substring used to detect dummy states that should render with reduced
        annotation alpha.
    dummy_alpha
        Alpha multiplier for dummy-state annotations.
    side_token
        String token that starts side-reaction parsing.
    show_main_to_product
        Whether to draw the dotted main-path connection to the product after a
        side-reaction branch.
    main_to_product_alpha, main_to_product_linestyle, main_to_product_lw
        Style controls for the main-to-product connector.
    main_to_product_drop_frac
        Fraction of the connector x-distance kept flat before dropping to the
        product energy.
    main_to_product_drop_points, main_to_product_flat_points
        Optional explicit point counts for the connector segments.
    product_x_offset
        Horizontal spacing between multiple product-like states.
    overlay
        Overlay mode: ``"auto"``, ``"off"``, or ``"on"``.
    overlay_annotate
        Annotation mode for overlay profiles: ``"none"``, ``"energy"``, or
        ``"full"``.
    overlay_alpha, overlay_lw_scale
        Alpha and line-width scaling for overlay profiles.
    marker, overlay_markers
        Marker style for reference and overlay points.
    show_legend
        Whether to draw a legend.
    profile_label
        Legend label for a single profile.
    overlay_colors
        Optional overlay color mapping or sequence. A two-item tuple sets
        ``(main_color, side_color)``.
    same_energy_tol, same_energy_mode, same_energy_tag
        Controls for suppressing or tagging matching overlay energies.
    show_state_labels, state_label_rotation, state_label_pad
        X-axis state-label controls.
    font_size, state_label_fontsize, energy_fontsize, axis_label_fontsize,
    tick_label_fontsize, legend_fontsize, same_energy_tag_fontsize
        Font-size controls. Specific values override `font_size`.

    Returns
    -------
    tuple
        ``(fig, ax)`` for the Matplotlib figure and axes.
    """
    base_fontsize = 12.0 if font_size is None else float(font_size)
    state_label_fontsize = (
        base_fontsize
        if state_label_fontsize is None
        else float(state_label_fontsize)
    )
    energy_fontsize = (
        base_fontsize
        if energy_fontsize is None
        else float(energy_fontsize)
    )
    axis_label_fontsize = (
        base_fontsize
        if axis_label_fontsize is None
        else float(axis_label_fontsize)
    )
    tick_label_fontsize = (
        base_fontsize
        if tick_label_fontsize is None
        else float(tick_label_fontsize)
    )
    legend_fontsize = (
        base_fontsize
        if legend_fontsize is None
        else float(legend_fontsize)
    )
    same_energy_tag_fontsize = (
        energy_fontsize
        if same_energy_tag_fontsize is None
        else float(same_energy_tag_fontsize)
    )

    def _plot_one(
        profile_name,
        profile_states,
        ax_,
        is_reference,
        ref_x_map,
        ref_prod_xs,
        ref_energy_map,
        overlay_idx,
    ):
        (
            entries,
            seg_ids,
            side_anchor_label,
            side_connector_rise_frac,
            side_legend_label,
        ) = _parse_entries(
            profile_states
        )

        names = [e[0] for e in entries]
        E = np.array([e[1] for e in entries], dtype=float)

        if is_reference or not ref_x_map:
            x = _compute_x_single(entries, product_x_offset)
        else:
            x = _compute_x_from_reference(
                entries,
                ref_x_map,
                ref_prod_xs,
                product_x_offset,
            )

        profile_energy_map = _build_energy_map(entries)

        product_indices = [i for i, lab in enumerate(names) if _is_product(lab)]
        main_product_idx = product_indices[0] if product_indices else (len(entries) - 1)
        side_product_idx = (
            product_indices[1] if len(product_indices) >= 2 else main_product_idx
        )

        side_start_idx = None
        for i, sid in enumerate(seg_ids):
            if sid == 1:
                side_start_idx = i
                break

        main_color, side_color = _resolve_colors(
            overlay_colors,
            profile_name,
            is_reference,
            overlay_idx,
            side_start_idx is not None,
        )
        point_colors = [main_color] * len(entries)
        a = 1.0 if is_reference else float(overlay_alpha)
        lw = (1.5 * float(overlay_lw_scale)) if not is_reference else 1.5
        z_line = 5 if is_reference else 3
        z_scatter = 6 if is_reference else 4
        z_conn = 2.5 if is_reference else 2.0
        legend_marker = marker if is_reference else (
            overlay_markers.get(profile_name, marker)
            if isinstance(overlay_markers, dict)
            else marker
        )
        side_legend_meta = (
            {
                "profile_name": None if profile_name is None else str(profile_name),
                "label": str(side_legend_label) if side_legend_label is not None else None,
                "color": side_color,
                "alpha": a,
                "marker": legend_marker,
            }
            if side_start_idx is not None
            else None
        )

        if side_start_idx is None:
            x_i, E_i = _dedup_for_interp(x, E)
            xs = np.linspace(x_i.min(), x_i.max(), int(n_points))
            interp = PchipInterpolator(x_i, E_i)
            Es = interp(xs)

            ax_.plot(
                xs,
                Es,
                marker="",
                alpha=a,
                linewidth=lw,
                color=main_color,
                zorder=z_line
            )
            m = marker if is_reference else (
                overlay_markers.get(profile_name, marker)
                if isinstance(overlay_markers, dict)
                else marker
            )
            ax_.scatter(
                x,
                E,
                zorder=z_scatter,
                color=main_color,
                alpha=a,
                marker=m,
                s=30,
            )
        else:
            if side_start_idx == 0:
                raise ValueError(f"{side_token!r} cannot be the first entry.")

            main_end = side_start_idx - 1

            x_main = x[: main_end + 1]
            E_main = E[: main_end + 1]
            x_main_i, E_main_i = _dedup_for_interp(x_main, E_main)

            xs_main = np.linspace(
                x_main_i.min(), x_main_i.max(), max(2, int(n_points * 0.6))
            )
            interp_main = PchipInterpolator(x_main_i, E_main_i)
            Es_main = interp_main(xs_main)

            ax_.plot(
                xs_main,
                Es_main,
                marker="",
                alpha=a,
                linewidth=lw,
                color=main_color,
                zorder=z_line
            )
            m = marker if is_reference else (
                overlay_markers.get(profile_name, marker)
                if isinstance(overlay_markers, dict)
                else marker
            )            
            ax_.scatter(
                x_main,
                E_main,
                zorder=z_scatter,
                color=main_color,
                alpha=a,
                marker=m,
                s=30,
            )

            side_anchor_idx = main_end
            if side_anchor_label is not None:
                target = side_anchor_label.lower().strip()
                for j, (lab, _, _) in enumerate(entries):
                    if _norm_label(lab) == target:
                        side_anchor_idx = j
                        break
                else:
                    raise ValueError(
                        f"side-rxn anchor {side_anchor_label!r} not found among labels."
                    )

            if side_anchor_idx >= side_start_idx:
                raise ValueError(
                    f"side-rxn anchor {side_anchor_label!r} must be before side segment."
                )

            side_idxs = [i for i in range(side_start_idx, len(entries))]
            if main_product_idx in side_idxs and main_product_idx != side_product_idx:
                side_idxs = [i for i in side_idxs if i != main_product_idx]

            if side_product_idx not in side_idxs and side_product_idx >= side_start_idx:
                side_idxs.append(side_product_idx)
                side_idxs = sorted(set(side_idxs))

            for idx in side_idxs:
                point_colors[idx] = side_color
            point_colors[main_product_idx] = main_color

            x_side_main = x[side_idxs]
            E_side_main = E[side_idxs]
            x_side_i, E_side_i = _dedup_for_interp(x_side_main, E_side_main)

            xs_side = np.linspace(
                float(x_side_i.min()),
                float(x_side_i.max()),
                max(2, int(n_points * 0.6)),
            )
            interp_side = PchipInterpolator(x_side_i, E_side_i)
            Es_side = interp_side(xs_side)

            ax_.plot(
                xs_side,
                Es_side,
                marker="",
                alpha=a,
                linewidth=lw,
                color=side_color,
                zorder=z_line
            )
            m = marker if is_reference else (
                overlay_markers.get(profile_name, marker)
                if isinstance(overlay_markers, dict)
                else marker
            )            
            ax_.scatter(
                x_side_main,
                E_side_main,
                zorder=z_scatter,
                color=side_color,
                alpha=a,
                marker=m,
                s=30,
            )

            x0 = float(x[side_anchor_idx])
            y0 = float(E[side_anchor_idx])
            x1c = float(x[side_start_idx])
            y1c = float(E[side_start_idx])

            frac = (
                0.0
                if side_connector_rise_frac is None
                else float(side_connector_rise_frac)
            )
            frac = min(max(frac, 0.0), 1.0)

            x_rise = x0 + frac * (x1c - x0)

            xs_flat = np.linspace(x0, x_rise, 60, endpoint=False)
            ys_flat = np.full_like(xs_flat, y0, dtype=float)

            xs_rise = np.linspace(x_rise, x1c, 120)
            denom = (x1c - x_rise)
            if denom == 0:
                ys_rise = np.full_like(xs_rise, y1c, dtype=float)
            else:
                t = (xs_rise - x_rise) / denom
                t = np.clip(t, 0.0, 1.0)
                s = t * t * (3.0 - 2.0 * t)
                ys_rise = y0 + (y1c - y0) * s

            xs_conn = np.concatenate([xs_flat, xs_rise])
            ys_conn = np.concatenate([ys_flat, ys_rise])

            ax_.plot(
                xs_conn,
                ys_conn,
                linestyle=":",
                linewidth=3.0,
                alpha=a,
                marker="",
                color=side_color,
                zorder=z_conn
            )

            if show_main_to_product and len(x) >= 2:
                x0u = float(x[main_end])
                y0u = float(E[main_end])
                x1u = float(x[main_product_idx])
                y1u = float(E[main_product_idx])

                frac = min(max(float(main_to_product_drop_frac), 0.0), 1.0)
                x_drop = x0u + frac * (x1u - x0u)

                n_flat = (
                    int(main_to_product_flat_points)
                    if main_to_product_flat_points is not None
                    else max(20, int(n_points * 0.15))
                )
                n_drop = (
                    int(main_to_product_drop_points)
                    if main_to_product_drop_points is not None
                    else max(80, int(n_points * 0.35))
                )

                xs_flat = np.linspace(x0u, x_drop, max(2, n_flat), endpoint=False)
                ys_flat = np.full_like(xs_flat, y0u, dtype=float)

                xs_drop = np.linspace(x_drop, x1u, max(2, n_drop))
                denom = (x1u - x_drop)
                if denom == 0:
                    ys_drop = np.full_like(xs_drop, y1u, dtype=float)
                else:
                    t = (xs_drop - x_drop) / denom
                    t = np.clip(t, 0.0, 1.0)
                    s = t * t * (3.0 - 2.0 * t)
                    ys_drop = y0u + (y1u - y0u) * s

                xs_usual = np.concatenate([xs_flat, xs_drop])
                ys_usual = np.concatenate([ys_flat, ys_drop])

                mp_color = "C0" if is_reference else main_color

                ax_.plot(
                    xs_usual,
                    ys_usual,
                    linestyle=main_to_product_linestyle,
                    linewidth=main_to_product_lw,
                    alpha=main_to_product_alpha * a,
                    marker="",
                    color=mp_color,
                    zorder=z_conn
                )
                m = marker if is_reference else (
                    overlay_markers.get(profile_name, marker)
                    if isinstance(overlay_markers, dict)
                    else marker
                )                
                ax_.scatter(
                    [x1u],
                    [y1u],
                    zorder=z_scatter,
                    color=mp_color,
                    alpha=a,
                    marker=m,
                    s=30,
                )

        # --- Energy annotations (labels are handled on x-axis if enabled) ---
        # --- Annotations ---
        if is_reference:
            do_annotate = bool(annotate_energies)
        else:
            do_annotate = overlay_annotate in {"energy", "full"}

        if do_annotate:
            for i, (xi, Ei, label) in enumerate(zip(x, E, names), start=1):
                key = _norm_label(label)
                is_dummy = dummy_substr.lower() in key

                # keep your "same energy" suppression for overlays
                if not is_reference and ref_energy_map is not None:
                    ref_e = ref_energy_map.get(key)
                    if ref_e is not None and abs(float(Ei) - float(ref_e)) <= float(
                        same_energy_tol
                    ):
                        continue

                placement_counts = _parse_placement(entries[i - 1][2])
                if placement_counts is None:
                    is_int = key.startswith(int_prefix.lower())
                    if i == 1:
                        placement_counts = {"left": 1, "right": 0, "top": 0, "bottom": 0}
                    elif i == len(entries):
                        placement_counts = {"right": 1, "left": 0, "top": 0, "bottom": 0}
                    elif is_int:
                        placement_counts = {"bottom": 1, "top": 0, "left": 0, "right": 0}
                    else:
                        placement_counts = {"top": 1, "bottom": 0, "left": 0, "right": 0}

                txt_color = point_colors[i - 1]
                alpha = 1.0 if is_reference else float(overlay_alpha)

                if not multi:
                    # SINGLE-MOLECULE: restore original label+energy annotations
                    top_n = placement_counts["top"]
                    bottom_n = placement_counts["bottom"]
                    left_n = placement_counts["left"]
                    right_n = placement_counts["right"]

                    dx = 0
                    dy = 0
                    ha = "center"
                    va = "center"

                    if left_n:
                        dx = -12 * left_n
                        ha = "right"
                    elif right_n:
                        dx = 12 * right_n
                        ha = "left"

                    if top_n:
                        dy = abs(label_offset_up) * top_n
                        va = "bottom"
                    elif bottom_n:
                        dy = -abs(label_offset_down) * bottom_n
                        va = "top"

                    add_arrow = max(top_n, bottom_n, left_n, right_n) > 1

                    if annotate_energies:
                        text = f"{label}\n{Ei:.{decimals}f}"
                    else:
                        text = f"{label}"

                    a = (dummy_alpha if is_dummy else 1.0) * alpha

                    arrowprops = None
                    if add_arrow:
                        arrowprops = {
                            "arrowstyle": "->",
                            "lw": 0.8,
                            "alpha": a * 0.8,
                            "shrinkA": 0,
                            "shrinkB": 6,
                            "mutation_scale": 8,
                        }

                    ax_.annotate(
                        text,
                        (xi, Ei),
                        textcoords="offset points",
                        xytext=(dx, dy),
                        ha=ha,
                        va=va,
                        alpha=a,
                        arrowprops=arrowprops,
                        color=txt_color,
                        fontsize=energy_fontsize,
                    )
                else:
                    # MULTI-MOLECULE: energy-only (state names are on x-axis)
                    _annotate_energy_only(
                        ax_=ax_,
                        xi=float(xi),
                        Ei=float(Ei),
                        alpha=alpha,
                        color=txt_color,
                        placement_counts=placement_counts,
                        is_dummy=is_dummy,
                        decimals=decimals,
                        label_offset_up=label_offset_up,
                        label_offset_down=label_offset_down,
                        dummy_alpha=dummy_alpha,
                        energy_fontsize=energy_fontsize,
                    )

        if is_reference:
            x_map = {}
            prod_xs = []
            ordered = []
            for xi, lab in zip(x, names):
                k = _norm_label(lab)
                x_map[k] = float(xi)
                ordered.append((float(xi), str(lab)))
                if _is_product(lab):
                    prod_xs.append(float(xi))
            return x_map, prod_xs, profile_energy_map, ordered, side_legend_meta

        return None, None, profile_energy_map, None, side_legend_meta

    # ---- Detect multi-profile input (no breaking of current list input) ----
    multi = False
    profiles = None

    if isinstance(states, dict):
        profiles = list(states.items())
        multi = True
    elif isinstance(states, (list, tuple)) and states:
        first = states[0]
        if (
            isinstance(first, (list, tuple))
            and len(first) == 2
            and isinstance(first[0], str)
            and isinstance(first[1], (list, tuple))
        ):
            profiles = list(states)
            multi = True

    if overlay == "off":
        multi = False
        profiles = None
    elif overlay == "on":
        if not multi:
            raise ValueError("overlay='on' requires dict or list-of-(name, states).")

    if show_state_labels is None:
        show_state_labels = bool(multi)

    created_fig = False
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
        created_fig = True
    else:
        fig = ax.figure

    ref_x_map = {}
    ref_prod_xs = []
    ref_energy_map = None
    ref_ordered = None
    side_legend_metas = []

    if not multi:
        _, _, _, _, side_meta = _plot_one(
            profile_name=None,
            profile_states=states,
            ax_=ax,
            is_reference=True,
            ref_x_map=ref_x_map,
            ref_prod_xs=ref_prod_xs,
            ref_energy_map=None,
            overlay_idx=0,
        )
        if side_meta is not None:
            side_legend_metas.append(side_meta)
    else:
        ref_name, ref_states = profiles[0]
        ref_x_map, ref_prod_xs, ref_energy_map, ref_ordered, side_meta = _plot_one(
            profile_name=ref_name,
            profile_states=ref_states,
            ax_=ax,
            is_reference=True,
            ref_x_map=ref_x_map,
            ref_prod_xs=ref_prod_xs,
            ref_energy_map=None,
            overlay_idx=0,
        )
        if side_meta is not None:
            side_legend_metas.append(side_meta)

        overlay_energy_maps = []
        for k, (name, st) in enumerate(profiles[1:], start=0):
            _, _, e_map, _, side_meta = _plot_one(
                profile_name=name,
                profile_states=st,
                ax_=ax,
                is_reference=False,
                ref_x_map=ref_x_map,
                ref_prod_xs=ref_prod_xs,
                ref_energy_map=ref_energy_map,
                overlay_idx=k,
            )
            overlay_energy_maps.append(e_map)
            if side_meta is not None:
                side_legend_metas.append(side_meta)

        if same_energy_mode == "tag" and annotate_energies and ref_energy_map is not None:
            for key, ref_e in ref_energy_map.items():
                matched = False
                for om in overlay_energy_maps:
                    oe = om.get(key) if om is not None else None
                    if oe is None:
                        continue
                    if abs(float(oe) - float(ref_e)) <= float(same_energy_tol):
                        matched = True
                        break
                if not matched:
                    continue

                xi = float(ref_x_map[key])
                yi = float(ref_e)
                ax.annotate(
                    same_energy_tag,
                    (xi, yi),
                    textcoords="offset points",
                    xytext=(8, 0),
                    ha="left",
                    va="center",
                    alpha=1.0,
                    color="C0",
                    fontsize=same_energy_tag_fontsize,
                )

        if show_legend:
            handles = []
            labels = []
            for i, (name, _) in enumerate(profiles):
                if name is None:
                    continue
                if i == 0:
                    color = "C0"
                    a = 1.0
                else:
                    if isinstance(overlay_colors, dict) and name in overlay_colors:
                        spec = overlay_colors[name]
                        color = spec[0] if isinstance(spec, (tuple, list)) else spec
                    else:
                        color = f"C{i}"
                    a = overlay_alpha
                if i == 0:
                    m = marker
                else:
                    if isinstance(overlay_markers, dict):
                        m = overlay_markers.get(name, marker)
                    else:
                        m = marker

                h = plt.Line2D(
                    [0],
                    [0],
                    color=color,
                    alpha=a,
                    marker=m,
                    linestyle="-",
                )
                handles.append(h)
                labels.append(str(name))
            for meta in side_legend_metas:
                label = meta["label"]
                if label is None:
                    continue

                style_meta = _style_meta_for_side_label(
                    label,
                    meta,
                    side_legend_metas,
                )
                h = plt.Line2D(
                    [0],
                    [0],
                    color=style_meta["color"],
                    alpha=style_meta["alpha"],
                    marker=style_meta["marker"],
                    linestyle="-",
                )
                handles.append(h)
                labels.append(label)
            if handles:
                ax.legend(handles, labels, frameon=False, fontsize=legend_fontsize)

    if not multi and show_legend:
        handles = []
        labels = []

        if profile_label is not None:
            h = plt.Line2D(
                [0],
                [0],
                color="C0",
                alpha=1.0,
                marker=marker,
                linestyle="-",
            )
            handles.append(h)
            labels.append(str(profile_label))

        for meta in side_legend_metas:
            label = meta["label"]
            if label is None:
                continue

            h = plt.Line2D(
                [0],
                [0],
                color=meta["color"],
                alpha=meta["alpha"],
                marker=meta["marker"],
                linestyle="-",
            )
            handles.append(h)
            labels.append(label)

        if handles:
            ax.legend(handles, labels, frameon=False, fontsize=legend_fontsize)

    ax.set_ylabel(ylabel, fontsize=axis_label_fontsize)

    # --- Bottom labels (states) ---
    if show_state_labels:
        # Use reference ordering if available (multi); otherwise derive from single.
        if ref_ordered is None:
            entries, _, _, _, _ = _parse_entries(states)
            x_single = _compute_x_single(entries, product_x_offset)
            ref_ordered = [(float(xi), str(lab)) for xi, (lab, _, _) in zip(x_single, entries)]

        # If there are duplicated x (multiple products), matplotlib will still
        # accept them; labels may overlap, but the offsets typically separate them.
        xs = [p[0] for p in ref_ordered]
        labs = [p[1] for p in ref_ordered]

        ax.set_xticks(xs)
        ax.set_xticklabels(labs, rotation=state_label_rotation,
                           fontsize=state_label_fontsize)
        ax.tick_params(axis="x", pad=state_label_pad)

        hide_x_ticks = False

    # --- Limits ---
    if ref_x_map:
        xmin = min(ref_x_map.values())
        xmax = max(ref_x_map.values())
    else:
        xmin, xmax = ax.get_xlim()

    left_pad = 1.05
    right_pad = 0.8
    ax.set_xlim(xmin - left_pad, xmax + right_pad)

    ax.grid(bool(grid))
    ax.set_facecolor("white")

    if hide_x_ticks:
        ax.set_xticks([])
    elif not show_state_labels:
        ax.tick_params(axis="x", labelsize=tick_label_fontsize)

    if hide_y_ticks:
        ax.set_yticks([])
    else:
        ax.tick_params(axis="y", labelsize=tick_label_fontsize)

    if hide_spines:
        for spine in ax.spines.values():
            spine.set_visible(False)

    if created_fig:
        fig.tight_layout()

    return fig, ax
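The overlay decision depends only on the shape of `states`. The sketch below restates the detection branch from the source above as a standalone function, so you can check the accepted input shapes without Matplotlib. `detect_profiles` is a hypothetical name and is not part of frust. The sketch also does not model the `overlay="off"` and `overlay="on"` overrides.

```python
from typing import Any, List, Optional, Sequence, Tuple


def detect_profiles(states: Any) -> Tuple[bool, Optional[List[Tuple[str, Sequence]]]]:
    """Return (multi, profiles) the way plot_energy_profile classifies `states`."""
    if isinstance(states, dict):
        # A mapping {profile_name: states} is always treated as multi-profile input.
        return True, list(states.items())
    if isinstance(states, (list, tuple)) and states:
        first = states[0]
        # A list counts as multi-profile only if its first entry looks like (name, states).
        if (
            isinstance(first, (list, tuple))
            and len(first) == 2
            and isinstance(first[0], str)
            and isinstance(first[1], (list, tuple))
        ):
            return True, list(states)
    # Anything else is a single profile of (label, energy[, placement]) entries.
    return False, None


single = [("R", 0.0), ("TS1", 18.2), ("int1", -3.1), ("P", -10.0)]
overlay = {"cat A": single, "cat B": [("R", 0.0), ("TS1", 15.7), ("P", -10.0)]}

print(detect_profiles(single)[0])                 # False
print(detect_profiles(overlay)[0])                # True
print(detect_profiles(list(overlay.items()))[0])  # True
```

In a single profile, each `(label, energy)` entry has a number as its second item, so it never satisfies the `(name, states)` test.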

plot_lig

plot_lig(df: DataFrame, substrate_names: Union[str, List[str]], exclude_coords: Optional[List[str]] = None, coord_indices: Optional[Union[List[int], slice]] = slice(-1, None), **kwargs: Any) -> None

Display molecules filtered by substrate name(s).

Source code in frust/vis/molecules.py
def plot_lig(
    df: pd.DataFrame,
    substrate_names: Union[str, List[str]],
    exclude_coords: Optional[List[str]] = None,
    coord_indices: Optional[Union[List[int], slice]] = slice(-1, None),
    **kwargs: Any
) -> None:
    """Display molecules filtered by substrate name(s)."""
    if isinstance(substrate_names, str):
        substrate_names = [substrate_names]

    plot_mols(
        df,
        substrate_filter=substrate_names,
        exclude_coords=exclude_coords,
        coord_indices=coord_indices,
        **kwargs
    )

plot_mols

plot_mols(df: DataFrame, row_indices: Optional[List[int]] = None, substrate_filter: Optional[List[str]] = None, rpos_filter: Optional[List[Union[str, int]]] = None, exclude_coords: Optional[List[str]] = None, include_coords: Optional[List[str]] = None, coord_indices: Optional[Union[List[int], slice]] = slice(-1, None), dark: bool = False, **molto3d_kwargs: Any) -> None

Display molecules from a dataframe with filtering capabilities.

Args:
    df: DataFrame with molecular data.
    row_indices: Positional row indices to display. If None, displays all rows.
    substrate_filter: List of substrate names to include. If None, includes all.
    rpos_filter: List of rpos values to include. If None, includes all.
    exclude_coords: List of coordinate column patterns to exclude.
    include_coords: List of coordinate column patterns to include (overrides exclude_coords).
    coord_indices: List of indices or a slice selecting coordinate columns (overrides include_coords and exclude_coords). The default, slice(-1, None), keeps only the last coordinate column, so pass coord_indices=None to filter by pattern instead.
    dark: If True, use a dark background by default. Ignored if 'background_color' is explicitly provided in molto3d_kwargs.
    **molto3d_kwargs: Additional arguments passed to MolTo3DGrid.

Returns: None

Source code in frust/vis/molecules.py
def plot_mols(
    df: pd.DataFrame,
    row_indices: Optional[List[int]] = None,
    substrate_filter: Optional[List[str]] = None,
    rpos_filter: Optional[List[Union[str, int]]] = None,
    exclude_coords: Optional[List[str]] = None,
    include_coords: Optional[List[str]] = None,
    coord_indices: Optional[Union[List[int], slice]] = slice(-1, None),
    dark: bool = False,
    **molto3d_kwargs: Any
) -> None:
    """
    Display molecules from a dataframe with filtering capabilities.

    Args:
        df: DataFrame with molecular data
        row_indices: List of row indices to display (if None, displays all
            rows)
        substrate_filter: List of substrate names to include (if None, includes all)
        rpos_filter: List of rpos values to include (if None, includes all)
        exclude_coords: List of coordinate column patterns to exclude
        include_coords: List of coordinate column patterns to include
            (overrides exclude)
        coord_indices: List of indices or a slice for coordinate columns
            (overrides include/exclude).
        dark: If True, use a dark background by default. Ignored if
            'background_color' is explicitly provided in molto3d_kwargs.
        **molto3d_kwargs: Additional arguments to pass to MolTo3DGrid

    Returns:
        None
    """
    filtered_df = normalize_dataframe(df)

    if row_indices is not None:
        filtered_df = filtered_df.iloc[row_indices]

    if substrate_filter is not None:
        filtered_df = filtered_df[
            filtered_df['substrate_name'].isin(substrate_filter)
        ]

    if rpos_filter is not None:
        filtered_df = filtered_df[filtered_df['rpos'].isin(rpos_filter)]

    if filtered_df.empty:
        print("No molecules match the specified filters.")
        return

    coord_columns = [
        c for c in filtered_df.columns
        if "coords" in c or str(c).endswith("-oc") or str(c).endswith("-opt_coords")
    ]

    if coord_indices is not None:
        if isinstance(coord_indices, slice):
            coord_columns = coord_columns[coord_indices]
        else:
            coord_columns = [
                coord_columns[i] for i in coord_indices
                if 0 <= i < len(coord_columns)
            ]
    elif include_coords is not None:
        coord_columns = [
            c for c in coord_columns
            if any(pattern in c for pattern in include_coords)
        ]
    elif exclude_coords is not None:
        coord_columns = [
            c for c in coord_columns
            if not any(pattern in c for pattern in exclude_coords)
        ]

    if not coord_columns:
        print("No coordinate columns found after filtering.")
        return

    print(f"Found {len(coord_columns)} coordinate columns: {coord_columns}")
    print(f"Processing {len(filtered_df)} rows")

    all_mols = []
    all_legends = []

    for idx, row in filtered_df.iterrows():
        atoms = row["atoms"]
        substrate_name = row["substrate_name"]
        rpos = row["rpos"]

        for coord_col in coord_columns:
            coords = row[coord_col]

            if coords is not None:
                if isinstance(coords, np.ndarray):
                    is_valid = coords.size > 0 and not pd.isna(coords).all()
                else:
                    is_valid = (not pd.isna(coords)
                                if not isinstance(coords, list)
                                else len(coords) > 0)

                if is_valid:
                    mol = ac2mol(atoms, coords)
                    if mol is not None:
                        all_mols.append(mol)

                        coord_type = (coord_col.replace("coords_", "")
                                      .replace("_coords", ""))
                        if rpos is None or pd.isna(rpos):
                            legend = f"{substrate_name}\n{coord_type}"
                        else:
                            legend = f"{substrate_name} r{rpos}\n{coord_type}"
                        all_legends.append(legend)

    if not all_mols:
        print("No valid molecules could be generated.")
        return

    print(f"Generated {len(all_mols)} molecules for display")

    # Defaults (can be overridden by molto3d_kwargs)
    molto3d_args = {
        'legends': all_legends,
        'show_labels': False,
        'show_confs': True,
        # Honour `dark`; an explicit background_color in molto3d_kwargs still wins via update().
        'background_color': 'black' if dark else 'white',
        'cell_size': (400, 400),
        'columns': len(coord_columns) if coord_indices is None else 4,
        'linked': False,
        'kekulize': True,
        'show_charges': True,
    }

    molto3d_args.update(molto3d_kwargs)

    MolTo3DGrid(all_mols, **molto3d_args)
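`coord_indices` takes precedence over the pattern filters and defaults to `slice(-1, None)`. As a result, `include_coords` and `exclude_coords` have no effect unless you also pass `coord_indices=None`. The following pure-Python sketch reproduces the column-selection rules from `plot_mols` on a plain list of column names, so you can check the precedence without RDKit or a DataFrame. The helper `select_coord_columns` and the example column names are hypothetical.

```python
from typing import List, Optional, Sequence, Union


def select_coord_columns(
    columns: Sequence[str],
    coord_indices: Optional[Union[List[int], slice]] = slice(-1, None),
    include_coords: Optional[List[str]] = None,
    exclude_coords: Optional[List[str]] = None,
) -> List[str]:
    # Same detection rule as plot_mols: "coords" substring or "-oc"/"-opt_coords" suffix.
    coord_columns = [
        c for c in columns
        if "coords" in c or str(c).endswith("-oc") or str(c).endswith("-opt_coords")
    ]
    if coord_indices is not None:
        # Highest precedence: positional selection; out-of-range list indices are skipped.
        if isinstance(coord_indices, slice):
            return coord_columns[coord_indices]
        return [coord_columns[i] for i in coord_indices if 0 <= i < len(coord_columns)]
    if include_coords is not None:
        # Second: keep columns containing any include pattern (exclude is then ignored).
        return [c for c in coord_columns if any(p in c for p in include_coords)]
    if exclude_coords is not None:
        return [c for c in coord_columns if not any(p in c for p in exclude_coords)]
    return coord_columns


cols = ["substrate_name", "rpos", "atoms", "coords_embedded", "xtb-oc", "DFT-opt_coords"]

print(select_coord_columns(cols))                                          # ['DFT-opt_coords']
print(select_coord_columns(cols, exclude_coords=["xtb"]))                  # ['DFT-opt_coords']
print(select_coord_columns(cols, coord_indices=None, exclude_coords=["xtb"]))
# ['coords_embedded', 'DFT-opt_coords']
```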

plot_regression_outliers

plot_regression_outliers(df: DataFrame, x_col: str = 'dE', y_col: str = 'dG', xlabel: str = '$\\Delta$E, kcal/mol', ylabel: str = '$\\Delta$G, kcal/mol', font_size: int = 14, label_col: str = 'substrate_name', rpos_col: str = 'rpos', method: str = 'spearman', num_outliers: int = 2, size: tuple = (8, 6), plot_1x: bool = False, equal_axis: bool = False, regression_text: str = 'legend', regression_text_loc: Union[str, Tuple[float, float]] = 'best', rmsd_unit: str = 'kcal/mol') -> pd.DataFrame

Plot x vs y with linear fit, score outliers, and annotate top points.

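Args:
    df (pd.DataFrame): Input data.
    x_col (str): Name of the column to use for x values. Defaults to "dE".
    y_col (str): Name of the column to use for y values. Defaults to "dG".
    xlabel (str): X-axis label. Defaults to "$\Delta$E, kcal/mol".
    ylabel (str): Y-axis label. Defaults to "$\Delta$G, kcal/mol".
    font_size (int): Font size for axis labels, tick labels, and the legend. Defaults to 14.
    label_col (str): Column used for point labels. Defaults to "substrate_name".
    rpos_col (str): Column used for position annotations. Defaults to "rpos".
    method (str, optional): Outlier scoring method. "pearson" scores the absolute residual from the linear fit; "spearman" scores the absolute difference between the x and y ranks. Defaults to "spearman".
    num_outliers (int, optional): Number of top outliers to annotate. Defaults to 2.
    size (tuple): Figure size. Defaults to (8, 6).
    plot_1x (bool): If True, also plot the unit-slope line y = x + c, where c is the mean of y - x. Defaults to False.
    equal_axis (bool): If True, use identical x and y limits and an equal aspect ratio. Defaults to False.
    regression_text (str, optional): Where to place regression statistics. Use "legend" for the historical behavior, "plot" to place them inside the axes, "both" for both locations, or "none" to hide them. Defaults to "legend".
    regression_text_loc (str or tuple, optional): Location for in-plot regression text. Named locations are "best", "upper left", "upper right", "lower left", and "lower right". A tuple is interpreted as axes-fraction coordinates. Defaults to "best".
    rmsd_unit (str, optional): Unit displayed after RMSD values. Use an empty string to omit the unit. Defaults to "kcal/mol".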

Returns: pd.DataFrame: DataFrame of the top outliers sorted by score.

Source code in frust/vis/regression.py
def plot_regression_outliers(
    df: pd.DataFrame,
    x_col: str = "dE",
    y_col: str = "dG",
    xlabel: str = r"$\Delta$E, kcal/mol",
    ylabel: str = r"$\Delta$G, kcal/mol",
    font_size: int = 14,
    label_col: str = "substrate_name",
    rpos_col: str = "rpos",
    method: str = "spearman",
    num_outliers: int = 2,
    size: tuple = (8, 6),
    plot_1x: bool = False,
    equal_axis: bool = False,
    regression_text: str = "legend",
    regression_text_loc: Union[str, Tuple[float, float]] = "best",
    rmsd_unit: str = "kcal/mol",
) -> pd.DataFrame:
    """Plot x vs y with linear fit, score outliers, and annotate top points.

    Args:
        df (pd.DataFrame): Input data.
        x_col (str): Name of the column to use for x values. Defaults to "dE".
        y_col (str): Name of the column to use for y values. Defaults to "dG".
        label_col (str): Column used for point labels. Defaults to
            "substrate_name".
        rpos_col (str): Column used for position annotations. Defaults to
            "rpos".
        method (str, optional): Scoring method, "pearson" or "spearman".
            Defaults to "spearman".
        num_outliers (int, optional): Number of top outliers to annotate.
            Defaults to 2.
        regression_text (str, optional): Where to place regression statistics.
            Use "legend" for the historical behavior, "plot" to place them
            inside the axes, "both" for both locations, or "none" to hide
            them. Defaults to "legend".
        regression_text_loc (str or tuple, optional): Location for in-plot
            regression text. Named locations are "best", "upper left",
            "upper right", "lower left", and "lower right". A tuple is
            interpreted as axes-fraction coordinates. Defaults to "best".
        rmsd_unit (str, optional): Unit displayed after RMSD values. Use an
            empty string to omit the unit. Defaults to "kcal/mol".

    Returns:
        pd.DataFrame: DataFrame of the top outliers sorted by score.
    """
    for col in (x_col, y_col, label_col, rpos_col):
        if col not in df.columns:
            raise ValueError(f"Column not found: {col}")
    if method not in ("pearson", "spearman"):
        raise ValueError(f"Invalid method: {method}")
    if regression_text not in ("legend", "plot", "both", "none"):
        raise ValueError(f"Invalid regression_text: {regression_text}")

    data = df.copy()
    x = data[x_col]
    y = data[y_col]

    lr = linregress(x, y)
    y_fit = lr.slope * x + lr.intercept
    rho, _ = spearmanr(x, y)

    c = float(np.mean(y - x))
    y_hat = x + c

    # Metrics
    y_arr = np.asarray(y, dtype=float)
    yfit_arr = np.asarray(y_fit, dtype=float)
    yhat_arr = np.asarray(y_hat, dtype=float)

    rmsd_fit = float(np.sqrt(np.mean((y_arr - yfit_arr) ** 2)))
    rmsd_hat = float(np.sqrt(np.mean((y_arr - yhat_arr) ** 2)))

    sst = float(np.sum((y_arr - np.mean(y_arr)) ** 2))
    sse_hat = float(np.sum((y_arr - yhat_arr) ** 2))
    r2_hat = 1.0 - (sse_hat / sst) if sst > 0 else np.nan

    rho_hat, _ = spearmanr(y_hat, y)

    # Print equations to stdout (not on the plot)
    eq_label = (f"y = {lr.slope:.2f}x "
                f"{'+' if lr.intercept >= 0 else '-'} "
                f"{abs(lr.intercept):.2f}")
    print("[INFO]: Linear relation:", eq_label)
    eq2_label = (f"y = 1x "
                 f"{'+' if c >= 0 else '-'} "
                 f"{abs(c):.2f}")
    print("[INFO]: Error relationship: ", eq2_label)

    rmsd_unit_suffix = f" {rmsd_unit}" if rmsd_unit else ""
    fit_stats_label = (f"$R^2$={lr.rvalue**2:.3f}, "
                       f"spearman={rho:.3f}, "
                       f"RMSD={rmsd_fit:.3f}{rmsd_unit_suffix}")
    offset_stats_label = (f"$R^2$={r2_hat:.3f}, "
                          f"spearman={rho_hat:.3f}, "
                          f"RMSD={rmsd_hat:.3f}{rmsd_unit_suffix}")

    if method == "pearson":
        data["score"] = (y - y_fit).abs()
    else:
        data["rank_x"] = x.rank()
        data["rank_y"] = y.rank()
        data["score"] = (data["rank_y"] - data["rank_x"]).abs()

    outliers = data.nlargest(num_outliers, "score")

    style_ctx = (plt.style.context('dark_background')
                 if theme.darkmode else nullcontext())
    with style_ctx:
        fig, ax = plt.subplots(figsize=size)
        ax.scatter(x, y, alpha=0.7)
        ax.plot(
            x, y_fit, color="red", marker="",
            label=fit_stats_label
            if regression_text in ("legend", "both") else "linear fit"
        )
        if plot_1x:
            ax.plot(
                x, y_hat, marker="",
                label=offset_stats_label
                if regression_text in ("legend", "both") else "1:1 offset"
            )

        for _, row in outliers.iterrows():
            label = f"{row[label_col]}-r{int(row[rpos_col])}"
            ax.annotate(
                label,
                (row[x_col], row[y_col]),
                textcoords="offset points",
                xytext=(5, 5),
                ha="left",
                fontsize=8,
                arrowprops=dict(arrowstyle="->", lw=0.5)
            )
        ax.set_xlabel(xlabel, fontsize=font_size)
        ax.set_ylabel(ylabel, fontsize=font_size)
        ax.tick_params(axis="both", labelsize=font_size)
        ax.legend(fontsize=font_size)
        ax.grid(True)
        if equal_axis:
            xmin = min(x.min(), y.min())
            xmax = max(x.max(), y.max())

            ax.set_xlim(xmin, xmax)
            ax.set_ylim(xmin, xmax)
            ax.set_aspect("equal", adjustable="box")

        if regression_text in ("plot", "both"):
            text_lines = [
                f"Fit: {eq_label}",
                f"$R^2$={lr.rvalue**2:.3f}, spearman={rho:.3f}",
                f"RMSD={rmsd_fit:.3f}{rmsd_unit_suffix}",
            ]
            if plot_1x:
                text_lines.extend([
                    f"Offset fit: {eq2_label}",
                    f"$R^2$={r2_hat:.3f}, spearman={rho_hat:.3f}",
                    f"RMSD={rmsd_hat:.3f}{rmsd_unit_suffix}",
                ])

            if isinstance(regression_text_loc, tuple):
                x_text, y_text = regression_text_loc
                ha = "left" if x_text <= 0.5 else "right"
                va = "bottom" if y_text <= 0.5 else "top"
            else:
                loc_lookup = {
                    "upper left": (0.03, 0.97, "left", "top"),
                    "upper right": (0.97, 0.97, "right", "top"),
                    "lower left": (0.03, 0.03, "left", "bottom"),
                    "lower right": (0.97, 0.03, "right", "bottom"),
                }
                loc = regression_text_loc
                if loc == "best":
                    loc = _least_crowded_text_loc(x, y)
                if loc not in loc_lookup:
                    raise ValueError(
                        f"Invalid regression_text_loc: {regression_text_loc}"
                    )
                x_text, y_text, ha, va = loc_lookup[loc]

            ax.text(
                x_text, y_text, "\n".join(text_lines),
                transform=ax.transAxes,
                ha=ha,
                va=va,
                fontsize=max(font_size - 2, 8),
                bbox={
                    "boxstyle": "round,pad=0.35",
                    "facecolor": "black" if theme.darkmode else "white",
                    "edgecolor": "0.5",
                    "alpha": 0.85,
                },
            )

        fig.tight_layout()
        plt.show()

    return outliers
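You can reproduce both scoring methods without pandas or SciPy. The sketch below mirrors the logic above using hypothetical helper names and made-up data:

- "pearson" scores the absolute residual from an ordinary least-squares line (the slope and intercept that `linregress` reports).
- "spearman" scores the absolute difference between average ranks, which matches the default tie handling of pandas `rank()`.
- The `plot_1x` offset is the mean of y - x.

```python
from statistics import fmean
from typing import List, Sequence


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks with ties averaged, like pandas' default Series.rank()."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j across a run of tied values.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of the 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def score_outliers(x: Sequence[float], y: Sequence[float], method: str = "spearman") -> List[float]:
    """Per-point outlier score following the two methods described above."""
    if method == "pearson":
        # Ordinary least-squares line, the same slope/intercept linregress reports.
        mx, my = fmean(x), fmean(y)
        sxx = sum((xi - mx) ** 2 for xi in x)
        sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
        slope = sxy / sxx
        intercept = my - slope * mx
        return [abs(yi - (slope * xi + intercept)) for xi, yi in zip(x, y)]
    if method == "spearman":
        rx, ry = average_ranks(x), average_ranks(y)
        return [abs(b - a) for a, b in zip(rx, ry)]
    raise ValueError(f"Invalid method: {method}")


def top_outliers(scores: Sequence[float], n: int) -> List[int]:
    """Indices of the n largest scores; ties keep input order, like nlargest."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n]


def offset_constant(x: Sequence[float], y: Sequence[float]) -> float:
    """c in the unit-slope line y = x + c drawn when plot_1x=True."""
    return fmean(yi - xi for xi, yi in zip(x, y))


x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [1.5, 9.0, 3.6, 4.4, 5.4]
print(score_outliers(x, y, "spearman"))  # [0.0, 3.0, 1.0, 1.0, 1.0]
print(top_outliers(score_outliers(x, y, "pearson"), 1))  # [1]
```

Both methods flag the second point in this toy data. They diverge when a point has a large residual but keeps its rank order.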

plot_row

plot_row(df: DataFrame, row_index: int = 0, exclude_coords: Optional[List[str]] = None, coord_indices: Optional[Union[List[int], slice]] = slice(-1, None), **kwargs: Any) -> None

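Display the coordinate columns of a single row, selected by position. With the default coord_indices=slice(-1, None), only the last coordinate column is shown; pass coord_indices=None to show all of them.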

Source code in frust/vis/molecules.py
def plot_row(
    df: pd.DataFrame,
    row_index: int = 0,
    exclude_coords: Optional[List[str]] = None,
    coord_indices: Optional[Union[List[int], slice]] = slice(-1, None),
    **kwargs: Any
) -> None:
    """Display all coordinate types for a single row."""
    plot_mols(
        df,
        row_indices=[row_index],
        exclude_coords=exclude_coords,
        coord_indices=coord_indices,
        **kwargs
    )

plot_rpos

plot_rpos(df: DataFrame, rpos_values: Union[str, int, List[Union[str, int]]], exclude_coords: Optional[List[str]] = None, coord_indices: Optional[Union[List[int], slice]] = slice(-1, None), **kwargs: Any) -> None

Display molecules filtered by rpos value(s).

Source code in frust/vis/molecules.py
def plot_rpos(
    df: pd.DataFrame,
    rpos_values: Union[str, int, List[Union[str, int]]],
    exclude_coords: Optional[List[str]] = None,
    coord_indices: Optional[Union[List[int], slice]] = slice(-1, None),
    **kwargs: Any
) -> None:
    """Display molecules filtered by rpos value(s)."""
    if isinstance(rpos_values, (str, int)):
        rpos_values = [rpos_values]

    plot_mols(
        df,
        rpos_filter=rpos_values,
        exclude_coords=exclude_coords,
        coord_indices=coord_indices,
        **kwargs
    )

set_theme

set_theme(dark: bool = True) -> None

Set module-wide visualization theme.

Parameters:

Name Type Description Default
dark bool

If True, helpers that support a dark background render with dark mode defaults.

True
Source code in frust/vis/theme.py
def set_theme(dark: bool = True) -> None:
    """Set module-wide visualization theme.

    Parameters
    ----------
    dark
        If ``True``, helpers that support a dark background render with dark
        mode defaults.
    """
    global darkmode
    darkmode = dark

use_darkmode

use_darkmode(on: bool = True)

Temporarily enable dark mode within a context manager.

Parameters:

Name Type Description Default
on bool

Whether dark mode should be active inside the context.

True

Yields:

Type Description
None

Control returns to the caller with the previous setting restored when the context exits.

Source code in frust/vis/theme.py
@contextmanager
def use_darkmode(on: bool = True):
    """Temporarily enable dark mode within a context manager.

    Parameters
    ----------
    on
        Whether dark mode should be active inside the context.

    Yields
    ------
    None
        Control returns to the caller with the previous setting restored when
        the context exits.
    """
    global darkmode
    prev = darkmode
    darkmode = on
    try:
        yield
    finally:
        darkmode = prev
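`set_theme` and `use_darkmode` both rebind a module-level `darkmode` flag. Helpers therefore have to read the flag at call time, as `plot_regression_outliers` does with `theme.darkmode`. A name imported with `from ... import darkmode` would be a stale copy. The sketch below imitates this pattern with an in-memory stand-in module instead of frust's real `theme` module, and it sets attributes on that module rather than using `global`. `background` is a hypothetical consumer.

```python
import types
from contextlib import contextmanager

# Hypothetical in-memory stand-in for frust/vis/theme.py so the sketch runs on its own.
theme = types.ModuleType("theme_sketch")
theme.darkmode = False


def set_theme(dark: bool = True) -> None:
    theme.darkmode = dark


@contextmanager
def use_darkmode(on: bool = True):
    prev = theme.darkmode
    theme.darkmode = on
    try:
        yield
    finally:
        # Restored even if the body raises.
        theme.darkmode = prev


def background() -> str:
    # Hypothetical consumer: reads the flag at call time through the module.
    return "black" if theme.darkmode else "white"


stale = theme.darkmode  # what `from theme import darkmode` would capture
with use_darkmode():
    print(background(), stale)  # black False
print(background())  # white
```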

Energy Profile Internals

energy_profile

plot_energy_profile

plot_energy_profile(states, ylabel: str = 'ΔG (kcal/mol)', n_points: int = 500, figsize=(8, 3.5), annotate_energies: bool = True, decimals: int = 1, int_prefix: str = 'int', label_offset_up: int = 8, label_offset_down: int = 12, hide_y_ticks: bool = True, hide_x_ticks: bool = True, hide_spines: bool = True, grid: bool = False, ax=None, dummy_substr: str = 'dummy', dummy_alpha: float = 0.5, side_token: str = 'side-rxn', show_main_to_product: bool = True, main_to_product_alpha: float = 1, main_to_product_linestyle: str = ':', main_to_product_lw: float = 3.0, main_to_product_bow: float = 0.5, main_to_product_drop_frac: float = 0.65, main_to_product_drop_points: int | None = None, main_to_product_flat_points: int | None = None, product_x_offset: float = 0.18, overlay: str = 'auto', overlay_annotate: str = 'energy', overlay_alpha: float = 0.35, overlay_lw_scale: float = 1.0, marker: str = 'o', overlay_markers=None, show_legend: bool = True, profile_label: str | None = None, overlay_colors=None, same_energy_tol: float = 0.001, same_energy_mode: str = 'hide', same_energy_tag: str = '≡', show_state_labels: bool | None = None, state_label_rotation: float = 0.0, font_size: float | None = None, state_label_fontsize: float | None = None, energy_fontsize: float | None = None, axis_label_fontsize: float | None = None, tick_label_fontsize: float | None = None, legend_fontsize: float | None = None, same_energy_tag_fontsize: float | None = None, state_label_pad: float = 6.0)

Plot one or more reaction energy profiles.

Parameters:

Name Type Description Default
states

Single profile as a sequence of (label, energy[, placement]) entries, or multiple profiles as a mapping/list of (profile_name, states). A string marker such as "side-rxn@int2@0.6#Side product" starts a side-reaction segment.

required
ylabel str

Y-axis label.

'ΔG (kcal/mol)'
n_points int

Number of interpolation points used for smooth profile curves.

500
figsize

Figure size used when ax is not provided.

(8, 3.5)
annotate_energies bool

Whether to annotate energies for the reference profile.

True
decimals int

Number of decimal places shown in energy labels.

1
int_prefix str

Label prefix treated as an intermediate for default label placement.

'int'
label_offset_up int

Offset, in points, of annotations placed above a state.

8
label_offset_down int

Offset, in points, of annotations placed below a state.

12
hide_y_ticks bool

Whether to remove the y-axis ticks.

True
hide_x_ticks bool

Whether to remove the x-axis ticks. Ignored when state labels are shown.

True
hide_spines bool

Whether to hide all axes spines.

True
grid bool

Whether to show the Matplotlib grid.

False
ax

Existing Matplotlib axes. If omitted, a new figure and axes are created.

None
dummy_substr str

Substring used to detect dummy states that should render with reduced annotation alpha.

'dummy'
dummy_alpha float

Alpha multiplier for dummy-state annotations.

0.5
side_token str

String token that starts side-reaction parsing.

'side-rxn'
show_main_to_product bool

Whether to draw the dotted main-path connection to the product after a side-reaction branch.

True
main_to_product_alpha float

Opacity of the main-to-product connector.

1
main_to_product_linestyle str

Matplotlib line style of the main-to-product connector.

':'
main_to_product_lw float

Line width of the main-to-product connector.

3.0
main_to_product_drop_frac float

Fraction of the connector x-distance kept flat before dropping to the product energy.

0.65
main_to_product_drop_points int | None

Optional explicit point counts for the connector segments.

None
main_to_product_flat_points int | None

Optional explicit point counts for the connector segments.

None
product_x_offset float

Horizontal spacing between multiple product-like states.

0.18
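overlay str

Overlay mode: "auto", "off", or "on". "auto" overlays profiles when states contains several named profiles; "on" requires such input and raises ValueError otherwise; "off" disables overlay handling.

'auto'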
overlay_annotate str

Annotation mode for overlay profiles: "none", "energy", or "full".

'energy'
overlay_alpha float

Opacity of overlay profiles.

0.35
overlay_lw_scale float

Line-width scaling factor applied to overlay profiles.

1.0
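marker str

Marker style for reference-profile points; also the fallback marker for overlay profiles.

'o'
overlay_markers dict | None

Optional mapping from overlay profile name to marker style. Profiles missing from the mapping use marker.

None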
show_legend bool

Whether to draw a legend.

True
profile_label str | None

Legend label for a single profile.

None
overlay_colors

Optional overlay color mapping or sequence. A two-item tuple sets (main_color, side_color).

None
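same_energy_tol float

Absolute tolerance within which an overlay energy counts as identical to the reference energy for the same state.

0.001
same_energy_mode str

How matching overlay energies are handled: 'hide' suppresses overlay energy labels that match the reference; 'tag' instead marks the matching reference energy with same_energy_tag (only when annotate_energies is True).

'hide'
same_energy_tag str

Text drawn next to a matched reference energy when same_energy_mode is 'tag'.

'≡'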
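show_state_labels bool | None

Whether to label the x-axis ticks with state names. None enables them only for multi-profile input. When enabled, x ticks are shown even if hide_x_ticks is True.

None
state_label_rotation float

Rotation of the x-axis state labels, in degrees.

0.0
state_label_pad float

Padding between the x-axis and the state labels, in points.

6.0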
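font_size float | None

Base font size for all text. If None, 12 is used.

None
state_label_fontsize float | None

Font size of the x-axis state labels. Overrides font_size.

None
energy_fontsize float | None

Font size of the energy annotations. Overrides font_size.

None
axis_label_fontsize float | None

Font size of the y-axis label. Overrides font_size.

None
tick_label_fontsize float | None

Font size of the tick labels. Overrides font_size.

None
legend_fontsize float | None

Font size of the legend entries. Overrides font_size.

None
same_energy_tag_fontsize float | None

Font size of the same-energy tag. If None, the resolved energy_fontsize is used.

None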
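A minimal sketch of how the font-size parameters resolve, mirroring the first statements of the function body: each specific size falls back to font_size (12 when unset), except same_energy_tag_fontsize, which falls back to the resolved energy size. resolve_font_sizes is a hypothetical helper written for illustration, not part of frust.vis.

```python
from typing import Optional


def resolve_font_sizes(
    font_size: Optional[float] = None,
    state_label_fontsize: Optional[float] = None,
    energy_fontsize: Optional[float] = None,
    axis_label_fontsize: Optional[float] = None,
    tick_label_fontsize: Optional[float] = None,
    legend_fontsize: Optional[float] = None,
    same_energy_tag_fontsize: Optional[float] = None,
) -> dict[str, float]:
    """Hypothetical helper mirroring the fallback chain in plot_energy_profile."""
    # Base size: 12 pt unless font_size is given.
    base = 12.0 if font_size is None else float(font_size)

    def pick(value: Optional[float], fallback: float) -> float:
        # A specific value wins; otherwise inherit the fallback.
        return fallback if value is None else float(value)

    energy = pick(energy_fontsize, base)
    return {
        "state_label": pick(state_label_fontsize, base),
        "energy": energy,
        "axis_label": pick(axis_label_fontsize, base),
        "tick_label": pick(tick_label_fontsize, base),
        "legend": pick(legend_fontsize, base),
        # The equivalence tag follows the energy size, not font_size directly.
        "same_energy_tag": pick(same_energy_tag_fontsize, energy),
    }


sizes = resolve_font_sizes(font_size=10, energy_fontsize=8)
print(sizes["legend"], sizes["energy"], sizes["same_energy_tag"])  # 10.0 8.0 8.0
```

So shrinking only energy_fontsize also shrinks the "≡" tag, while legend and axis text keep the base size.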

Returns:

Type Description
tuple

(fig, ax) for the Matplotlib figure and axes.
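Whether states is drawn as one profile or as a reference plus overlays depends only on its shape and on overlay. The standalone sketch below reproduces the detection block from the source listing so an input can be checked before plotting; detect_profiles is a hypothetical name, and the energies are made-up illustration values.

```python
from typing import Any


def detect_profiles(states: Any, overlay: str = "auto"):
    """Return (multi, profiles) using the same rules as plot_energy_profile."""
    multi = False
    profiles = None

    if isinstance(states, dict):
        # A mapping is always read as {profile_name: states}.
        profiles = list(states.items())
        multi = True
    elif isinstance(states, (list, tuple)) and states:
        first = states[0]
        # A list is multi-profile only if its first item looks like
        # (str name, list/tuple of states).
        if (
            isinstance(first, (list, tuple))
            and len(first) == 2
            and isinstance(first[0], str)
            and isinstance(first[1], (list, tuple))
        ):
            profiles = list(states)
            multi = True

    if overlay == "off":
        multi, profiles = False, None
    elif overlay == "on" and not multi:
        raise ValueError("overlay='on' requires dict or list-of-(name, states).")
    return multi, profiles


# Made-up energies, for illustration only.
single = [("R", 0.0), ("TS1", 18.2), ("int1", -3.1), ("P", -10.4)]
print(detect_profiles(single)[0])               # False
print(detect_profiles({"cat A": single})[0])    # True
print(detect_profiles([("cat A", single)])[0])  # True
```

Only the first item is inspected, so a list is classified by its first element alone; the first profile found becomes the reference that fixes the x positions for all overlays.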

Source code in frust/vis/energy_profile/api.py
def plot_energy_profile(
    states,
    ylabel: str = "ΔG (kcal/mol)",
    n_points: int = 500,
    figsize=(8, 3.5),
    annotate_energies: bool = True,
    decimals: int = 1,
    int_prefix: str = "int",
    label_offset_up: int = 8,
    label_offset_down: int = 12,
    hide_y_ticks: bool = True,
    hide_x_ticks: bool = True,
    hide_spines: bool = True,
    grid: bool = False,
    ax=None,
    dummy_substr: str = "dummy",
    dummy_alpha: float = 0.5,
    side_token: str = "side-rxn",
    show_main_to_product: bool = True,
    main_to_product_alpha: float = 1,
    main_to_product_linestyle: str = ":",
    main_to_product_lw: float = 3.0,
    main_to_product_bow: float = 0.5,
    main_to_product_drop_frac: float = 0.65,
    main_to_product_drop_points: int | None = None,
    main_to_product_flat_points: int | None = None,
    product_x_offset: float = 0.18,
    # --- multi-molecule overlays ---
    overlay: str = "auto",  # "auto" | "off" | "on"
    overlay_annotate: str = "energy",  # "none" | "energy" | "full"
    overlay_alpha: float = 0.35,
    overlay_lw_scale: float = 1.0,
    marker: str = "o",
    overlay_markers=None,
    show_legend: bool = True,
    profile_label: str | None = None,
    overlay_colors=None,
    same_energy_tol: float = 1e-3,
    same_energy_mode: str = "hide",  # "hide" | "tag"
    same_energy_tag: str = "≡",
    # --- x-axis state labels and font sizes ---
    show_state_labels: bool | None = None,
    state_label_rotation: float = 0.0,
    font_size: float | None = None,
    state_label_fontsize: float | None = None,
    energy_fontsize: float | None = None,
    axis_label_fontsize: float | None = None,
    tick_label_fontsize: float | None = None,
    legend_fontsize: float | None = None,
    same_energy_tag_fontsize: float | None = None,
    state_label_pad: float = 6.0,
):
    """Plot one or more reaction energy profiles.

    Parameters
    ----------
    states
        Single profile as a sequence of ``(label, energy[, placement])`` entries,
        or multiple profiles as a mapping/list of ``(profile_name, states)``.
        A string marker such as ``"side-rxn@int2@0.6#Side product"`` starts a
        side-reaction segment.
    ylabel
        Y-axis label.
    n_points
        Number of interpolation points used for smooth profile curves.
    figsize
        Figure size used when `ax` is not provided.
    annotate_energies
        Whether to annotate energies for the reference profile.
    decimals
        Number of decimal places shown in energy labels.
    int_prefix
        Label prefix treated as an intermediate for default label placement.
    label_offset_up, label_offset_down
        Point offsets used for top and bottom annotations.
    hide_y_ticks, hide_x_ticks, hide_spines
        Axis cleanup options.
    grid
        Whether to show the Matplotlib grid.
    ax
        Existing Matplotlib axes. If omitted, a new figure and axes are created.
    dummy_substr
        Substring used to detect dummy states that should render with reduced
        annotation alpha.
    dummy_alpha
        Alpha multiplier for dummy-state annotations.
    side_token
        String token that starts side-reaction parsing.
    show_main_to_product
        Whether to draw the dotted main-path connection to the product after a
        side-reaction branch.
    main_to_product_alpha, main_to_product_linestyle, main_to_product_lw
        Style controls for the main-to-product connector.
    main_to_product_drop_frac
        Fraction of the connector x-distance kept flat before dropping to the
        product energy.
    main_to_product_drop_points, main_to_product_flat_points
        Optional explicit point counts for the connector segments.
    product_x_offset
        Horizontal spacing between multiple product-like states.
    overlay
        Overlay mode: ``"auto"``, ``"off"``, or ``"on"``.
    overlay_annotate
        Annotation mode for overlay profiles: ``"none"``, ``"energy"``, or
        ``"full"``.
    overlay_alpha, overlay_lw_scale
        Alpha and line-width scaling for overlay profiles.
    marker, overlay_markers
        Marker style for reference and overlay points.
    show_legend
        Whether to draw a legend.
    profile_label
        Legend label for a single profile.
    overlay_colors
        Optional overlay color mapping or sequence. A two-item tuple sets
        ``(main_color, side_color)``.
    same_energy_tol, same_energy_mode, same_energy_tag
        Controls for suppressing or tagging matching overlay energies.
    show_state_labels, state_label_rotation, state_label_pad
        X-axis state-label controls.
    font_size, state_label_fontsize, energy_fontsize, axis_label_fontsize,
    tick_label_fontsize, legend_fontsize, same_energy_tag_fontsize
        Font-size controls. Specific values override `font_size`.

    Returns
    -------
    tuple
        ``(fig, ax)`` for the Matplotlib figure and axes.
    """
    base_fontsize = 12.0 if font_size is None else float(font_size)
    state_label_fontsize = (
        base_fontsize
        if state_label_fontsize is None
        else float(state_label_fontsize)
    )
    energy_fontsize = (
        base_fontsize
        if energy_fontsize is None
        else float(energy_fontsize)
    )
    axis_label_fontsize = (
        base_fontsize
        if axis_label_fontsize is None
        else float(axis_label_fontsize)
    )
    tick_label_fontsize = (
        base_fontsize
        if tick_label_fontsize is None
        else float(tick_label_fontsize)
    )
    legend_fontsize = (
        base_fontsize
        if legend_fontsize is None
        else float(legend_fontsize)
    )
    same_energy_tag_fontsize = (
        energy_fontsize
        if same_energy_tag_fontsize is None
        else float(same_energy_tag_fontsize)
    )

    def _plot_one(
        profile_name,
        profile_states,
        ax_,
        is_reference,
        ref_x_map,
        ref_prod_xs,
        ref_energy_map,
        overlay_idx,
    ):
        (
            entries,
            seg_ids,
            side_anchor_label,
            side_connector_rise_frac,
            side_legend_label,
        ) = _parse_entries(
            profile_states
        )

        names = [e[0] for e in entries]
        E = np.array([e[1] for e in entries], dtype=float)

        if is_reference or not ref_x_map:
            x = _compute_x_single(entries, product_x_offset)
        else:
            x = _compute_x_from_reference(
                entries,
                ref_x_map,
                ref_prod_xs,
                product_x_offset,
            )

        profile_energy_map = _build_energy_map(entries)

        product_indices = [i for i, lab in enumerate(names) if _is_product(lab)]
        main_product_idx = product_indices[0] if product_indices else (len(entries) - 1)
        side_product_idx = (
            product_indices[1] if len(product_indices) >= 2 else main_product_idx
        )

        side_start_idx = None
        for i, sid in enumerate(seg_ids):
            if sid == 1:
                side_start_idx = i
                break

        main_color, side_color = _resolve_colors(
            overlay_colors,
            profile_name,
            is_reference,
            overlay_idx,
            side_start_idx is not None,
        )
        point_colors = [main_color] * len(entries)
        a = 1.0 if is_reference else float(overlay_alpha)
        lw = (1.5 * float(overlay_lw_scale)) if not is_reference else 1.5
        z_line = 5 if is_reference else 3
        z_scatter = 6 if is_reference else 4
        z_conn = 2.5 if is_reference else 2.0
        legend_marker = marker if is_reference else (
            overlay_markers.get(profile_name, marker)
            if isinstance(overlay_markers, dict)
            else marker
        )
        side_legend_meta = (
            {
                "profile_name": None if profile_name is None else str(profile_name),
                "label": str(side_legend_label) if side_legend_label is not None else None,
                "color": side_color,
                "alpha": a,
                "marker": legend_marker,
            }
            if side_start_idx is not None
            else None
        )

        if side_start_idx is None:
            x_i, E_i = _dedup_for_interp(x, E)
            xs = np.linspace(x_i.min(), x_i.max(), int(n_points))
            interp = PchipInterpolator(x_i, E_i)
            Es = interp(xs)

            ax_.plot(
                xs,
                Es,
                marker="",
                alpha=a,
                linewidth=lw,
                color=main_color,
                zorder=z_line
            )
            m = marker if is_reference else (
                overlay_markers.get(profile_name, marker)
                if isinstance(overlay_markers, dict)
                else marker
            )
            ax_.scatter(
                x,
                E,
                zorder=z_scatter,
                color=main_color,
                alpha=a,
                marker=m,
                s=30,
            )
        else:
            if side_start_idx == 0:
                raise ValueError(f"{side_token!r} cannot be the first entry.")

            main_end = side_start_idx - 1

            x_main = x[: main_end + 1]
            E_main = E[: main_end + 1]
            x_main_i, E_main_i = _dedup_for_interp(x_main, E_main)

            xs_main = np.linspace(
                x_main_i.min(), x_main_i.max(), max(2, int(n_points * 0.6))
            )
            interp_main = PchipInterpolator(x_main_i, E_main_i)
            Es_main = interp_main(xs_main)

            ax_.plot(
                xs_main,
                Es_main,
                marker="",
                alpha=a,
                linewidth=lw,
                color=main_color,
                zorder=z_line
            )
            m = marker if is_reference else (
                overlay_markers.get(profile_name, marker)
                if isinstance(overlay_markers, dict)
                else marker
            )
            ax_.scatter(
                x_main,
                E_main,
                zorder=z_scatter,
                color=main_color,
                alpha=a,
                marker=m,
                s=30,
            )

            side_anchor_idx = main_end
            if side_anchor_label is not None:
                target = side_anchor_label.lower().strip()
                for j, (lab, _, _) in enumerate(entries):
                    if _norm_label(lab) == target:
                        side_anchor_idx = j
                        break
                else:
                    raise ValueError(
                        f"side-rxn anchor {side_anchor_label!r} not found among labels."
                    )

            if side_anchor_idx >= side_start_idx:
                raise ValueError(
                    f"side-rxn anchor {side_anchor_label!r} must be before side segment."
                )

            side_idxs = [i for i in range(side_start_idx, len(entries))]
            if main_product_idx in side_idxs and main_product_idx != side_product_idx:
                side_idxs = [i for i in side_idxs if i != main_product_idx]

            if side_product_idx not in side_idxs and side_product_idx >= side_start_idx:
                side_idxs.append(side_product_idx)
                side_idxs = sorted(set(side_idxs))

            for idx in side_idxs:
                point_colors[idx] = side_color
            point_colors[main_product_idx] = main_color

            x_side_main = x[side_idxs]
            E_side_main = E[side_idxs]
            x_side_i, E_side_i = _dedup_for_interp(x_side_main, E_side_main)

            xs_side = np.linspace(
                float(x_side_i.min()),
                float(x_side_i.max()),
                max(2, int(n_points * 0.6)),
            )
            interp_side = PchipInterpolator(x_side_i, E_side_i)
            Es_side = interp_side(xs_side)

            ax_.plot(
                xs_side,
                Es_side,
                marker="",
                alpha=a,
                linewidth=lw,
                color=side_color,
                zorder=z_line
            )
            m = marker if is_reference else (
                overlay_markers.get(profile_name, marker)
                if isinstance(overlay_markers, dict)
                else marker
            )
            ax_.scatter(
                x_side_main,
                E_side_main,
                zorder=z_scatter,
                color=side_color,
                alpha=a,
                marker=m,
                s=30,
            )

            x0 = float(x[side_anchor_idx])
            y0 = float(E[side_anchor_idx])
            x1c = float(x[side_start_idx])
            y1c = float(E[side_start_idx])

            frac = (
                0.0
                if side_connector_rise_frac is None
                else float(side_connector_rise_frac)
            )
            frac = min(max(frac, 0.0), 1.0)

            x_rise = x0 + frac * (x1c - x0)

            xs_flat = np.linspace(x0, x_rise, 60, endpoint=False)
            ys_flat = np.full_like(xs_flat, y0, dtype=float)

            xs_rise = np.linspace(x_rise, x1c, 120)
            denom = (x1c - x_rise)
            if denom == 0:
                ys_rise = np.full_like(xs_rise, y1c, dtype=float)
            else:
                t = (xs_rise - x_rise) / denom
                t = np.clip(t, 0.0, 1.0)
                s = t * t * (3.0 - 2.0 * t)
                ys_rise = y0 + (y1c - y0) * s

            xs_conn = np.concatenate([xs_flat, xs_rise])
            ys_conn = np.concatenate([ys_flat, ys_rise])

            ax_.plot(
                xs_conn,
                ys_conn,
                linestyle=":",
                linewidth=3.0,
                alpha=a,
                marker="",
                color=side_color,
                zorder=z_conn
            )

            if show_main_to_product and len(x) >= 2:
                x0u = float(x[main_end])
                y0u = float(E[main_end])
                x1u = float(x[main_product_idx])
                y1u = float(E[main_product_idx])

                frac = min(max(float(main_to_product_drop_frac), 0.0), 1.0)
                x_drop = x0u + frac * (x1u - x0u)

                n_flat = (
                    int(main_to_product_flat_points)
                    if main_to_product_flat_points is not None
                    else max(20, int(n_points * 0.15))
                )
                n_drop = (
                    int(main_to_product_drop_points)
                    if main_to_product_drop_points is not None
                    else max(80, int(n_points * 0.35))
                )

                xs_flat = np.linspace(x0u, x_drop, max(2, n_flat), endpoint=False)
                ys_flat = np.full_like(xs_flat, y0u, dtype=float)

                xs_drop = np.linspace(x_drop, x1u, max(2, n_drop))
                denom = (x1u - x_drop)
                if denom == 0:
                    ys_drop = np.full_like(xs_drop, y1u, dtype=float)
                else:
                    t = (xs_drop - x_drop) / denom
                    t = np.clip(t, 0.0, 1.0)
                    s = t * t * (3.0 - 2.0 * t)
                    ys_drop = y0u + (y1u - y0u) * s

                xs_usual = np.concatenate([xs_flat, xs_drop])
                ys_usual = np.concatenate([ys_flat, ys_drop])

                mp_color = "C0" if is_reference else main_color

                ax_.plot(
                    xs_usual,
                    ys_usual,
                    linestyle=main_to_product_linestyle,
                    linewidth=main_to_product_lw,
                    alpha=main_to_product_alpha * a,
                    marker="",
                    color=mp_color,
                    zorder=z_conn
                )
                m = marker if is_reference else (
                    overlay_markers.get(profile_name, marker)
                    if isinstance(overlay_markers, dict)
                    else marker
                )
                ax_.scatter(
                    [x1u],
                    [y1u],
                    zorder=z_scatter,
                    color=mp_color,
                    alpha=a,
                    marker=m,
                    s=30,
                )

        # --- Energy annotations (state names go on the x-axis when enabled) ---
        if is_reference:
            do_annotate = bool(annotate_energies)
        else:
            do_annotate = overlay_annotate in {"energy", "full"}

        if do_annotate:
            for i, (xi, Ei, label) in enumerate(zip(x, E, names), start=1):
                key = _norm_label(label)
                is_dummy = dummy_substr.lower() in key

                # Overlays: skip energies that match the reference within same_energy_tol.
                if not is_reference and ref_energy_map is not None:
                    ref_e = ref_energy_map.get(key)
                    if ref_e is not None and abs(float(Ei) - float(ref_e)) <= float(
                        same_energy_tol
                    ):
                        continue

                placement_counts = _parse_placement(entries[i - 1][2])
                if placement_counts is None:
                    is_int = key.startswith(int_prefix.lower())
                    if i == 1:
                        placement_counts = {"left": 1, "right": 0, "top": 0, "bottom": 0}
                    elif i == len(entries):
                        placement_counts = {"right": 1, "left": 0, "top": 0, "bottom": 0}
                    elif is_int:
                        placement_counts = {"bottom": 1, "top": 0, "left": 0, "right": 0}
                    else:
                        placement_counts = {"top": 1, "bottom": 0, "left": 0, "right": 0}

                txt_color = point_colors[i - 1]
                alpha = 1.0 if is_reference else float(overlay_alpha)

                if not multi:
                    # Single profile: annotate each point with its label and energy.
                    top_n = placement_counts["top"]
                    bottom_n = placement_counts["bottom"]
                    left_n = placement_counts["left"]
                    right_n = placement_counts["right"]

                    dx = 0
                    dy = 0
                    ha = "center"
                    va = "center"

                    if left_n:
                        dx = -12 * left_n
                        ha = "right"
                    elif right_n:
                        dx = 12 * right_n
                        ha = "left"

                    if top_n:
                        dy = abs(label_offset_up) * top_n
                        va = "bottom"
                    elif bottom_n:
                        dy = -abs(label_offset_down) * bottom_n
                        va = "top"

                    add_arrow = max(top_n, bottom_n, left_n, right_n) > 1

                    if annotate_energies:
                        text = f"{label}\n{Ei:.{decimals}f}"
                    else:
                        text = f"{label}"

                    a = (dummy_alpha if is_dummy else 1.0) * alpha

                    arrowprops = None
                    if add_arrow:
                        arrowprops = {
                            "arrowstyle": "->",
                            "lw": 0.8,
                            "alpha": a * 0.8,
                            "shrinkA": 0,
                            "shrinkB": 6,
                            "mutation_scale": 8,
                        }

                    ax_.annotate(
                        text,
                        (xi, Ei),
                        textcoords="offset points",
                        xytext=(dx, dy),
                        ha=ha,
                        va=va,
                        alpha=a,
                        arrowprops=arrowprops,
                        color=txt_color,
                        fontsize=energy_fontsize,
                    )
                else:
                    # MULTI-MOLECULE: energy-only (state names are on x-axis)
                    _annotate_energy_only(
                        ax_=ax_,
                        xi=float(xi),
                        Ei=float(Ei),
                        alpha=alpha,
                        color=txt_color,
                        placement_counts=placement_counts,
                        is_dummy=is_dummy,
                        decimals=decimals,
                        label_offset_up=label_offset_up,
                        label_offset_down=label_offset_down,
                        dummy_alpha=dummy_alpha,
                        energy_fontsize=energy_fontsize,
                    )

        if is_reference:
            x_map = {}
            prod_xs = []
            ordered = []
            for xi, lab in zip(x, names):
                k = _norm_label(lab)
                x_map[k] = float(xi)
                ordered.append((float(xi), str(lab)))
                if _is_product(lab):
                    prod_xs.append(float(xi))
            return x_map, prod_xs, profile_energy_map, ordered, side_legend_meta

        return None, None, profile_energy_map, None, side_legend_meta

    # ---- Detect multi-profile input (plain single-profile lists keep working) ----
    multi = False
    profiles = None

    if isinstance(states, dict):
        profiles = list(states.items())
        multi = True
    elif isinstance(states, (list, tuple)) and states:
        first = states[0]
        if (
            isinstance(first, (list, tuple))
            and len(first) == 2
            and isinstance(first[0], str)
            and isinstance(first[1], (list, tuple))
        ):
            profiles = list(states)
            multi = True

    if overlay == "off":
        multi = False
        profiles = None
    elif overlay == "on":
        if not multi:
            raise ValueError("overlay='on' requires dict or list-of-(name, states).")

    if show_state_labels is None:
        show_state_labels = bool(multi)

    created_fig = False
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
        created_fig = True
    else:
        fig = ax.figure

    ref_x_map = {}
    ref_prod_xs = []
    ref_energy_map = None
    ref_ordered = None
    side_legend_metas = []

    if not multi:
        _, _, _, _, side_meta = _plot_one(
            profile_name=None,
            profile_states=states,
            ax_=ax,
            is_reference=True,
            ref_x_map=ref_x_map,
            ref_prod_xs=ref_prod_xs,
            ref_energy_map=None,
            overlay_idx=0,
        )
        if side_meta is not None:
            side_legend_metas.append(side_meta)
    else:
        ref_name, ref_states = profiles[0]
        ref_x_map, ref_prod_xs, ref_energy_map, ref_ordered, side_meta = _plot_one(
            profile_name=ref_name,
            profile_states=ref_states,
            ax_=ax,
            is_reference=True,
            ref_x_map=ref_x_map,
            ref_prod_xs=ref_prod_xs,
            ref_energy_map=None,
            overlay_idx=0,
        )
        if side_meta is not None:
            side_legend_metas.append(side_meta)

        overlay_energy_maps = []
        for k, (name, st) in enumerate(profiles[1:], start=0):
            _, _, e_map, _, side_meta = _plot_one(
                profile_name=name,
                profile_states=st,
                ax_=ax,
                is_reference=False,
                ref_x_map=ref_x_map,
                ref_prod_xs=ref_prod_xs,
                ref_energy_map=ref_energy_map,
                overlay_idx=k,
            )
            overlay_energy_maps.append(e_map)
            if side_meta is not None:
                side_legend_metas.append(side_meta)

        if same_energy_mode == "tag" and annotate_energies and ref_energy_map is not None:
            for key, ref_e in ref_energy_map.items():
                matched = False
                for om in overlay_energy_maps:
                    oe = om.get(key) if om is not None else None
                    if oe is None:
                        continue
                    if abs(float(oe) - float(ref_e)) <= float(same_energy_tol):
                        matched = True
                        break
                if not matched:
                    continue

                xi = float(ref_x_map[key])
                yi = float(ref_e)
                ax.annotate(
                    same_energy_tag,
                    (xi, yi),
                    textcoords="offset points",
                    xytext=(8, 0),
                    ha="left",
                    va="center",
                    alpha=1.0,
                    color="C0",
                    fontsize=same_energy_tag_fontsize,
                )

        if show_legend:
            handles = []
            labels = []
            for i, (name, _) in enumerate(profiles):
                if name is None:
                    continue
                if i == 0:
                    color = "C0"
                    a = 1.0
                else:
                    if isinstance(overlay_colors, dict) and name in overlay_colors:
                        spec = overlay_colors[name]
                        color = spec[0] if isinstance(spec, (tuple, list)) else spec
                    else:
                        color = f"C{i}"
                    a = overlay_alpha
                if i == 0:
                    m = marker
                else:
                    if isinstance(overlay_markers, dict):
                        m = overlay_markers.get(name, marker)
                    else:
                        m = marker

                h = plt.Line2D(
                    [0],
                    [0],
                    color=color,
                    alpha=a,
                    marker=m,
                    linestyle="-",
                )
                handles.append(h)
                labels.append(str(name))
            for meta in side_legend_metas:
                label = meta["label"]
                if label is None:
                    continue

                style_meta = _style_meta_for_side_label(
                    label,
                    meta,
                    side_legend_metas,
                )
                h = plt.Line2D(
                    [0],
                    [0],
                    color=style_meta["color"],
                    alpha=style_meta["alpha"],
                    marker=style_meta["marker"],
                    linestyle="-",
                )
                handles.append(h)
                labels.append(label)
            if handles:
                ax.legend(handles, labels, frameon=False, fontsize=legend_fontsize)

    if not multi and show_legend:
        handles = []
        labels = []

        if profile_label is not None:
            h = plt.Line2D(
                [0],
                [0],
                color="C0",
                alpha=1.0,
                marker=marker,
                linestyle="-",
            )
            handles.append(h)
            labels.append(str(profile_label))

        for meta in side_legend_metas:
            label = meta["label"]
            if label is None:
                continue

            h = plt.Line2D(
                [0],
                [0],
                color=meta["color"],
                alpha=meta["alpha"],
                marker=meta["marker"],
                linestyle="-",
            )
            handles.append(h)
            labels.append(label)

        if handles:
            ax.legend(handles, labels, frameon=False, fontsize=legend_fontsize)

    ax.set_ylabel(ylabel, fontsize=axis_label_fontsize)

    # --- Bottom labels (states) ---
    if show_state_labels:
        # Use reference ordering if available (multi); otherwise derive from single.
        if ref_ordered is None:
            entries, _, _, _, _ = _parse_entries(states)
            x_single = _compute_x_single(entries, product_x_offset)
            ref_ordered = [(float(xi), str(lab)) for xi, (lab, _, _) in zip(x_single, entries)]

        # Duplicate x positions (e.g. several products) are accepted by matplotlib;
        # their labels may overlap, although the product x-offsets usually separate them.
        xs = [p[0] for p in ref_ordered]
        labs = [p[1] for p in ref_ordered]

        ax.set_xticks(xs)
        ax.set_xticklabels(labs, rotation=state_label_rotation,
                           fontsize=state_label_fontsize)
        ax.tick_params(axis="x", pad=state_label_pad)

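        # State labels live on the x ticks, so keep them: the hide_x_ticks
        # branch below would otherwise clear the ticks just set.
        hide_x_ticks = False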

    # --- Limits ---
    if ref_x_map:
        xmin = min(ref_x_map.values())
        xmax = max(ref_x_map.values())
    else:
        xmin, xmax = ax.get_xlim()

    # Fixed margins (in data units) beyond the outermost state positions.
    left_pad = 1.05
    right_pad = 0.8
    ax.set_xlim(xmin - left_pad, xmax + right_pad)

    ax.grid(bool(grid))
    ax.set_facecolor("white")

    if hide_x_ticks:
        ax.set_xticks([])
    elif not show_state_labels:
        ax.tick_params(axis="x", labelsize=tick_label_fontsize)

    if hide_y_ticks:
        ax.set_yticks([])
    else:
        ax.tick_params(axis="y", labelsize=tick_label_fontsize)

    if hide_spines:
        for spine in ax.spines.values():
            spine.set_visible(False)

    if created_fig:
        fig.tight_layout()

    return fig, ax
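The marker choice, legend ordering, and x-limit padding in the listing above can be restated as small pure functions, which makes the rules easy to check without drawing a figure. This is a minimal sketch: `resolve_marker`, `legend_labels`, and `padded_xlim` are hypothetical helpers, not part of `frust.vis`. The padding constants are copied from the listing, and the per-label styling that `_style_meta_for_side_label` applies in multi-profile mode is not modeled.

```python
from typing import Any, Mapping, Optional, Sequence

# Hypothetical helpers that mirror logic in plot_energy_profile; not part of frust.vis.

LEFT_PAD = 1.05  # same values as left_pad / right_pad in the listing
RIGHT_PAD = 0.8


def resolve_marker(index: int, name: str, marker: str,
                   overlay_markers: Optional[Mapping[str, str]]) -> str:
    """Profile 0 always uses `marker`; later profiles may override it by name via a dict."""
    if index == 0 or not isinstance(overlay_markers, dict):
        return marker
    return overlay_markers.get(name, marker)


def legend_labels(profile_names: Sequence[Any],
                  side_metas: Sequence[Mapping[str, Any]]) -> list[str]:
    """Profile names first, then side-reaction labels; side entries without a label are skipped."""
    labels = [str(name) for name in profile_names]
    labels.extend(meta["label"] for meta in side_metas if meta["label"] is not None)
    return labels


def padded_xlim(x_positions: Sequence[float]) -> tuple[float, float]:
    """Return (left, right) x-limits with the asymmetric margins used by the plot."""
    return min(x_positions) - LEFT_PAD, max(x_positions) + RIGHT_PAD


if __name__ == "__main__":
    print(resolve_marker(1, "B", "o", {"B": "s"}))  # overlay "B" gets its own marker
    print(legend_labels(["A", "B"], [{"label": "Side"}, {"label": None}]))
    print(padded_xlim([0.0, 1.0, 2.0]))
```

In single-profile mode the listing adds a main entry only when `profile_label` is set, which corresponds to calling `legend_labels([profile_label], ...)` or `legend_labels([], ...)`.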