@@ -109,6 +109,28 @@ def _normalize_attr(
109
109
return _normalize_scale (attr_combined , threshold )
110
110
111
111
112
+ def _create_default_plot (
113
+ # pyre-fixme[2]: Parameter must be annotated.
114
+ plt_fig_axis ,
115
+ # pyre-fixme[2]: Parameter must be annotated.
116
+ use_pyplot ,
117
+ # pyre-fixme[2]: Parameter must be annotated.
118
+ fig_size ,
119
+ ** pyplot_kwargs : Any ,
120
+ ) -> Tuple [Figure , Axes ]:
121
+ # Create plot if figure, axis not provided
122
+ if plt_fig_axis is not None :
123
+ plt_fig , plt_axis = plt_fig_axis
124
+ else :
125
+ if use_pyplot :
126
+ plt_fig , plt_axis = plt .subplots (figsize = fig_size , ** pyplot_kwargs )
127
+ else :
128
+ plt_fig = Figure (figsize = fig_size )
129
+ plt_axis = plt_fig .subplots (** pyplot_kwargs )
130
+ return plt_fig , plt_axis
131
+ # Figure.subplots returns Axes or array of Axes
132
+
133
+
112
134
def _initialize_cmap_and_vmin_vmax (
113
135
sign : str ,
114
136
) -> Tuple [Union [str , Colormap ], float , float ]:
@@ -338,16 +360,7 @@ def visualize_image_attr(
338
360
>>> # Displays blended heat map visualization of computed attributions.
339
361
>>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
340
362
"""
341
- # Create plot if figure, axis not provided
342
- if plt_fig_axis is not None :
343
- plt_fig , plt_axis = plt_fig_axis
344
- else :
345
- if use_pyplot :
346
- plt_fig , plt_axis = plt .subplots (figsize = fig_size )
347
- else :
348
- plt_fig = Figure (figsize = fig_size )
349
- plt_axis = plt_fig .subplots ()
350
- # Figure.subplots returns Axes or array of Axes
363
+ plt_fig , plt_axis = _create_default_plot (plt_fig_axis , use_pyplot , fig_size )
351
364
352
365
if original_image is not None :
353
366
if np .max (original_image ) <= 1.0 :
@@ -362,8 +375,10 @@ def visualize_image_attr(
362
375
)
363
376
364
377
# Remove ticks and tick labels from plot.
365
- plt_axis .xaxis .set_ticks_position ("none" )
366
- plt_axis .yaxis .set_ticks_position ("none" )
378
+ if plt_axis .xaxis is not None :
379
+ plt_axis .xaxis .set_ticks_position ("none" )
380
+ if plt_axis .yaxis is not None :
381
+ plt_axis .yaxis .set_ticks_position ("none" )
367
382
plt_axis .set_yticklabels ([])
368
383
plt_axis .set_xticklabels ([])
369
384
plt_axis .grid (visible = False )
@@ -528,6 +543,161 @@ def visualize_image_attr_multiple(
528
543
return plt_fig , plt_axis
529
544
530
545
546
+ def _plot_attrs_as_axvspan (
547
+ # pyre-fixme[2]: Parameter must be annotated.
548
+ attr_vals ,
549
+ # pyre-fixme[2]: Parameter must be annotated.
550
+ x_vals ,
551
+ # pyre-fixme[2]: Parameter must be annotated.
552
+ ax ,
553
+ # pyre-fixme[2]: Parameter must be annotated.
554
+ x_values ,
555
+ # pyre-fixme[2]: Parameter must be annotated.
556
+ cmap ,
557
+ # pyre-fixme[2]: Parameter must be annotated.
558
+ cm_norm ,
559
+ # pyre-fixme[2]: Parameter must be annotated.
560
+ alpha_overlay ,
561
+ ) -> None :
562
+ # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
563
+ half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
564
+
565
+ for icol , col_center in enumerate (x_vals ):
566
+ left = col_center - half_col_width
567
+ right = col_center + half_col_width
568
+ ax .axvspan (
569
+ xmin = left ,
570
+ xmax = right ,
571
+ # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
572
+ facecolor = (cmap (cm_norm (attr_vals [icol ]))), # type: ignore
573
+ edgecolor = None ,
574
+ alpha = alpha_overlay ,
575
+ )
576
+
577
+
578
+ def _visualize_overlay_individual (
579
+ # pyre-fixme[2]: Parameter must be annotated.
580
+ num_channels ,
581
+ # pyre-fixme[2]: Parameter must be annotated.
582
+ plt_axis_list ,
583
+ # pyre-fixme[2]: Parameter must be annotated.
584
+ x_values ,
585
+ # pyre-fixme[2]: Parameter must be annotated.
586
+ data ,
587
+ # pyre-fixme[2]: Parameter must be annotated.
588
+ channel_labels ,
589
+ # pyre-fixme[2]: Parameter must be annotated.
590
+ norm_attr ,
591
+ # pyre-fixme[2]: Parameter must be annotated.
592
+ cmap ,
593
+ # pyre-fixme[2]: Parameter must be annotated.
594
+ cm_norm ,
595
+ # pyre-fixme[2]: Parameter must be annotated.
596
+ alpha_overlay ,
597
+ # pyre-fixme[2]: Parameter must be annotated.
598
+ ** kwargs : Any ,
599
+ ) -> None :
600
+ # helper method for visualize_timeseries_attr
601
+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
602
+ for chan in range (num_channels ):
603
+ plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
604
+ if channel_labels is not None :
605
+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
606
+
607
+ _plot_attrs_as_axvspan (
608
+ norm_attr [chan ],
609
+ x_values ,
610
+ plt_axis_list [chan ],
611
+ x_values ,
612
+ cmap ,
613
+ cm_norm ,
614
+ alpha_overlay ,
615
+ )
616
+
617
+ plt .subplots_adjust (hspace = 0 )
618
+ pass
619
+
620
+
621
+ def _visualize_overlay_combined (
622
+ # pyre-fixme[2]: Parameter must be annotated.
623
+ num_channels ,
624
+ # pyre-fixme[2]: Parameter must be annotated.
625
+ plt_axis_list ,
626
+ # pyre-fixme[2]: Parameter must be annotated.
627
+ x_values ,
628
+ # pyre-fixme[2]: Parameter must be annotated.
629
+ data ,
630
+ # pyre-fixme[2]: Parameter must be annotated.
631
+ channel_labels ,
632
+ # pyre-fixme[2]: Parameter must be annotated.
633
+ norm_attr ,
634
+ # pyre-fixme[2]: Parameter must be annotated.
635
+ cmap ,
636
+ # pyre-fixme[2]: Parameter must be annotated.
637
+ cm_norm ,
638
+ # pyre-fixme[2]: Parameter must be annotated.
639
+ alpha_overlay ,
640
+ ** kwargs : Any ,
641
+ ) -> None :
642
+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
643
+
644
+ cycler = plt .cycler ("color" , matplotlib .colormaps ["Dark2" ].colors ) # type: ignore
645
+ plt_axis_list [0 ].set_prop_cycle (cycler )
646
+
647
+ for chan in range (num_channels ):
648
+ label = channel_labels [chan ] if channel_labels else None
649
+ plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
650
+
651
+ _plot_attrs_as_axvspan (
652
+ norm_attr ,
653
+ x_values ,
654
+ plt_axis_list [0 ],
655
+ x_values ,
656
+ cmap ,
657
+ cm_norm ,
658
+ alpha_overlay ,
659
+ )
660
+
661
+ plt_axis_list [0 ].legend (loc = "best" )
662
+
663
+
664
+ def _visualize_colored_graph (
665
+ # pyre-fixme[2]: Parameter must be annotated.
666
+ num_channels ,
667
+ # pyre-fixme[2]: Parameter must be annotated.
668
+ plt_axis_list ,
669
+ # pyre-fixme[2]: Parameter must be annotated.
670
+ x_values ,
671
+ # pyre-fixme[2]: Parameter must be annotated.
672
+ data ,
673
+ # pyre-fixme[2]: Parameter must be annotated.
674
+ channel_labels ,
675
+ # pyre-fixme[2]: Parameter must be annotated.
676
+ norm_attr ,
677
+ # pyre-fixme[2]: Parameter must be annotated.
678
+ cmap ,
679
+ # pyre-fixme[2]: Parameter must be annotated.
680
+ cm_norm ,
681
+ ** kwargs : Any ,
682
+ ) -> None :
683
+ # helper method for visualize_timeseries_attr
684
+ pyplot_kwargs = kwargs .get ("pyplot_kwargs" , {})
685
+ for chan in range (num_channels ):
686
+ points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
687
+ segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
688
+
689
+ lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
690
+ lc .set_array (norm_attr [chan , :])
691
+ plt_axis_list [chan ].add_collection (lc )
692
+ plt_axis_list [chan ].set_ylim (
693
+ 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
694
+ )
695
+ if channel_labels is not None :
696
+ plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
697
+
698
+ plt .subplots_adjust (hspace = 0 )
699
+
700
+
531
701
def visualize_timeseries_attr (
532
702
attr : npt .NDArray ,
533
703
data : npt .NDArray ,
@@ -686,8 +856,8 @@ def visualize_timeseries_attr(
686
856
687
857
num_subplots = num_channels
688
858
if (
689
- TimeseriesVisualizationMethod [method ]
690
- == TimeseriesVisualizationMethod .overlay_combined
859
+ TimeseriesVisualizationMethod [method ]. value
860
+ == TimeseriesVisualizationMethod .overlay_combined . value
691
861
):
692
862
num_subplots = 1
693
863
attr = np .sum (attr , axis = 0 ) # Merge attributions across channels
@@ -700,17 +870,9 @@ def visualize_timeseries_attr(
700
870
x_values = np .arange (timeseries_length )
701
871
702
872
# Create plot if figure, axis not provided
703
- if plt_fig_axis is not None :
704
- plt_fig , plt_axis = plt_fig_axis
705
- else :
706
- if use_pyplot :
707
- plt_fig , plt_axis = plt .subplots ( # type: ignore
708
- figsize = fig_size , nrows = num_subplots , sharex = True
709
- )
710
- else :
711
- plt_fig = Figure (figsize = fig_size )
712
- plt_axis = plt_fig .subplots (nrows = num_subplots , sharex = True ) # type: ignore
713
- # Figure.subplots returns Axes or array of Axes
873
+ plt_fig , plt_axis = _create_default_plot (
874
+ plt_fig_axis , use_pyplot , fig_size , nrows = num_subplots , sharex = True
875
+ )
714
876
715
877
if not isinstance (plt_axis , ndarray ):
716
878
plt_axis_list = np .array ([plt_axis ])
@@ -720,91 +882,30 @@ def visualize_timeseries_attr(
720
882
norm_attr = _normalize_attr (attr , sign , outlier_perc , reduction_axis = None )
721
883
722
884
# Set default colormap and bounds based on sign.
723
- if VisualizeSign [sign ] == VisualizeSign .all :
724
- default_cmap : Union [str , LinearSegmentedColormap ] = (
725
- LinearSegmentedColormap .from_list ("RdWhGn" , ["red" , "white" , "green" ])
726
- )
727
- vmin , vmax = - 1 , 1
728
- elif VisualizeSign [sign ] == VisualizeSign .positive :
729
- default_cmap = "Greens"
730
- vmin , vmax = 0 , 1
731
- elif VisualizeSign [sign ] == VisualizeSign .negative :
732
- default_cmap = "Reds"
733
- vmin , vmax = 0 , 1
734
- elif VisualizeSign [sign ] == VisualizeSign .absolute_value :
735
- default_cmap = "Blues"
736
- vmin , vmax = 0 , 1
737
- else :
738
- raise AssertionError ("Visualize Sign type is not valid." )
885
+ default_cmap , vmin , vmax = _initialize_cmap_and_vmin_vmax (sign )
739
886
cmap = cmap if cmap is not None else default_cmap
740
887
cmap = cm .get_cmap (cmap ) # type: ignore
741
888
cm_norm = colors .Normalize (vmin , vmax )
742
889
743
- # pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
744
- # pyre-fixme[2]: Parameter must be annotated.
745
- def _plot_attrs_as_axvspan (attr_vals , x_vals , ax ) -> None :
746
- # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
747
- half_col_width = (x_values [1 ] - x_values [0 ]) / 2.0
748
- for icol , col_center in enumerate (x_vals ):
749
- left = col_center - half_col_width
750
- right = col_center + half_col_width
751
- ax .axvspan (
752
- xmin = left ,
753
- xmax = right ,
754
- # pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
755
- facecolor = (cmap (cm_norm (attr_vals [icol ]))), # type: ignore
756
- edgecolor = None ,
757
- alpha = alpha_overlay ,
758
- )
759
-
760
- if (
761
- TimeseriesVisualizationMethod [method ]
762
- == TimeseriesVisualizationMethod .overlay_individual
763
- ):
764
- for chan in range (num_channels ):
765
- plt_axis_list [chan ].plot (x_values , data [chan , :], ** pyplot_kwargs )
766
- if channel_labels is not None :
767
- plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
768
-
769
- _plot_attrs_as_axvspan (norm_attr [chan ], x_values , plt_axis_list [chan ])
770
-
771
- plt .subplots_adjust (hspace = 0 )
772
-
773
- elif (
774
- TimeseriesVisualizationMethod [method ]
775
- == TimeseriesVisualizationMethod .overlay_combined
776
- ):
777
- # Dark colors are better in this case
778
- cycler = plt .cycler ("color" , matplotlib .colormaps ["Dark2" ]) # type: ignore
779
- plt_axis_list [0 ].set_prop_cycle (cycler )
780
-
781
- for chan in range (num_channels ):
782
- label = channel_labels [chan ] if channel_labels else None
783
- plt_axis_list [0 ].plot (x_values , data [chan , :], label = label , ** pyplot_kwargs )
784
-
785
- _plot_attrs_as_axvspan (norm_attr , x_values , plt_axis_list [0 ])
786
-
787
- plt_axis_list [0 ].legend (loc = "best" )
788
-
789
- elif (
790
- TimeseriesVisualizationMethod [method ]
791
- == TimeseriesVisualizationMethod .colored_graph
792
- ):
793
- for chan in range (num_channels ):
794
- points = np .array ([x_values , data [chan , :]]).T .reshape (- 1 , 1 , 2 )
795
- segments = np .concatenate ([points [:- 1 ], points [1 :]], axis = 1 )
796
-
797
- lc = LineCollection (segments , cmap = cmap , norm = cm_norm , ** pyplot_kwargs )
798
- lc .set_array (norm_attr [chan , :])
799
- plt_axis_list [chan ].add_collection (lc )
800
- plt_axis_list [chan ].set_ylim (
801
- 1.2 * np .min (data [chan , :]), 1.2 * np .max (data [chan , :])
802
- )
803
- if channel_labels is not None :
804
- plt_axis_list [chan ].set_ylabel (channel_labels [chan ])
805
-
806
- plt .subplots_adjust (hspace = 0 )
807
-
890
+ visualization_methods : Dict [str , Callable [..., Union [None , AxesImage ]]] = {
891
+ "overlay_individual" : _visualize_overlay_individual ,
892
+ "overlay_combined" : _visualize_overlay_combined ,
893
+ "colored_graph" : _visualize_colored_graph ,
894
+ }
895
+ kwargs = {
896
+ "num_channels" : num_channels ,
897
+ "plt_axis_list" : plt_axis_list ,
898
+ "x_values" : x_values ,
899
+ "data" : data ,
900
+ "channel_labels" : channel_labels ,
901
+ "norm_attr" : norm_attr ,
902
+ "cmap" : cmap ,
903
+ "cm_norm" : cm_norm ,
904
+ "alpha_overlay" : alpha_overlay ,
905
+ "pyplot_kwargs" : pyplot_kwargs ,
906
+ }
907
+ if method in visualization_methods :
908
+ visualization_methods [method ](** kwargs )
808
909
else :
809
910
raise AssertionError ("Invalid visualization method: {}" .format (method ))
810
911
0 commit comments