Skip to content

Commit f61847f

Browse files
jjunchofacebook-github-bot
authored andcommitted
'visualize_timeseries_attr' is too complex (#1384)
Summary: Pull Request resolved: #1384 This diff addresses the C901 in visualization.py by breaking down the method Reviewed By: vivekmig Differential Revision: D64513163 fbshipit-source-id: a7d3b658b41255124163b914b5e9a87fc424bd0e
1 parent 1c941a6 commit f61847f

File tree

1 file changed

+207
-106
lines changed

1 file changed

+207
-106
lines changed

captum/attr/_utils/visualization.py

Lines changed: 207 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,28 @@ def _normalize_attr(
109109
return _normalize_scale(attr_combined, threshold)
110110

111111

112+
def _create_default_plot(
113+
# pyre-fixme[2]: Parameter must be annotated.
114+
plt_fig_axis,
115+
# pyre-fixme[2]: Parameter must be annotated.
116+
use_pyplot,
117+
# pyre-fixme[2]: Parameter must be annotated.
118+
fig_size,
119+
**pyplot_kwargs: Any,
120+
) -> Tuple[Figure, Axes]:
121+
# Create plot if figure, axis not provided
122+
if plt_fig_axis is not None:
123+
plt_fig, plt_axis = plt_fig_axis
124+
else:
125+
if use_pyplot:
126+
plt_fig, plt_axis = plt.subplots(figsize=fig_size, **pyplot_kwargs)
127+
else:
128+
plt_fig = Figure(figsize=fig_size)
129+
plt_axis = plt_fig.subplots(**pyplot_kwargs)
130+
return plt_fig, plt_axis
131+
# Figure.subplots returns Axes or array of Axes
132+
133+
112134
def _initialize_cmap_and_vmin_vmax(
113135
sign: str,
114136
) -> Tuple[Union[str, Colormap], float, float]:
@@ -338,16 +360,7 @@ def visualize_image_attr(
338360
>>> # Displays blended heat map visualization of computed attributions.
339361
>>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
340362
"""
341-
# Create plot if figure, axis not provided
342-
if plt_fig_axis is not None:
343-
plt_fig, plt_axis = plt_fig_axis
344-
else:
345-
if use_pyplot:
346-
plt_fig, plt_axis = plt.subplots(figsize=fig_size)
347-
else:
348-
plt_fig = Figure(figsize=fig_size)
349-
plt_axis = plt_fig.subplots()
350-
# Figure.subplots returns Axes or array of Axes
363+
plt_fig, plt_axis = _create_default_plot(plt_fig_axis, use_pyplot, fig_size)
351364

352365
if original_image is not None:
353366
if np.max(original_image) <= 1.0:
@@ -362,8 +375,10 @@ def visualize_image_attr(
362375
)
363376

364377
# Remove ticks and tick labels from plot.
365-
plt_axis.xaxis.set_ticks_position("none")
366-
plt_axis.yaxis.set_ticks_position("none")
378+
if plt_axis.xaxis is not None:
379+
plt_axis.xaxis.set_ticks_position("none")
380+
if plt_axis.yaxis is not None:
381+
plt_axis.yaxis.set_ticks_position("none")
367382
plt_axis.set_yticklabels([])
368383
plt_axis.set_xticklabels([])
369384
plt_axis.grid(visible=False)
@@ -528,6 +543,161 @@ def visualize_image_attr_multiple(
528543
return plt_fig, plt_axis
529544

530545

546+
def _plot_attrs_as_axvspan(
547+
# pyre-fixme[2]: Parameter must be annotated.
548+
attr_vals,
549+
# pyre-fixme[2]: Parameter must be annotated.
550+
x_vals,
551+
# pyre-fixme[2]: Parameter must be annotated.
552+
ax,
553+
# pyre-fixme[2]: Parameter must be annotated.
554+
x_values,
555+
# pyre-fixme[2]: Parameter must be annotated.
556+
cmap,
557+
# pyre-fixme[2]: Parameter must be annotated.
558+
cm_norm,
559+
# pyre-fixme[2]: Parameter must be annotated.
560+
alpha_overlay,
561+
) -> None:
562+
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
563+
half_col_width = (x_values[1] - x_values[0]) / 2.0
564+
565+
for icol, col_center in enumerate(x_vals):
566+
left = col_center - half_col_width
567+
right = col_center + half_col_width
568+
ax.axvspan(
569+
xmin=left,
570+
xmax=right,
571+
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
572+
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
573+
edgecolor=None,
574+
alpha=alpha_overlay,
575+
)
576+
577+
578+
def _visualize_overlay_individual(
579+
# pyre-fixme[2]: Parameter must be annotated.
580+
num_channels,
581+
# pyre-fixme[2]: Parameter must be annotated.
582+
plt_axis_list,
583+
# pyre-fixme[2]: Parameter must be annotated.
584+
x_values,
585+
# pyre-fixme[2]: Parameter must be annotated.
586+
data,
587+
# pyre-fixme[2]: Parameter must be annotated.
588+
channel_labels,
589+
# pyre-fixme[2]: Parameter must be annotated.
590+
norm_attr,
591+
# pyre-fixme[2]: Parameter must be annotated.
592+
cmap,
593+
# pyre-fixme[2]: Parameter must be annotated.
594+
cm_norm,
595+
# pyre-fixme[2]: Parameter must be annotated.
596+
alpha_overlay,
597+
# pyre-fixme[2]: Parameter must be annotated.
598+
**kwargs: Any,
599+
) -> None:
600+
# helper method for visualize_timeseries_attr
601+
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
602+
for chan in range(num_channels):
603+
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
604+
if channel_labels is not None:
605+
plt_axis_list[chan].set_ylabel(channel_labels[chan])
606+
607+
_plot_attrs_as_axvspan(
608+
norm_attr[chan],
609+
x_values,
610+
plt_axis_list[chan],
611+
x_values,
612+
cmap,
613+
cm_norm,
614+
alpha_overlay,
615+
)
616+
617+
plt.subplots_adjust(hspace=0)
618+
pass
619+
620+
621+
def _visualize_overlay_combined(
622+
# pyre-fixme[2]: Parameter must be annotated.
623+
num_channels,
624+
# pyre-fixme[2]: Parameter must be annotated.
625+
plt_axis_list,
626+
# pyre-fixme[2]: Parameter must be annotated.
627+
x_values,
628+
# pyre-fixme[2]: Parameter must be annotated.
629+
data,
630+
# pyre-fixme[2]: Parameter must be annotated.
631+
channel_labels,
632+
# pyre-fixme[2]: Parameter must be annotated.
633+
norm_attr,
634+
# pyre-fixme[2]: Parameter must be annotated.
635+
cmap,
636+
# pyre-fixme[2]: Parameter must be annotated.
637+
cm_norm,
638+
# pyre-fixme[2]: Parameter must be annotated.
639+
alpha_overlay,
640+
**kwargs: Any,
641+
) -> None:
642+
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
643+
644+
cycler = plt.cycler("color", matplotlib.colormaps["Dark2"].colors) # type: ignore
645+
plt_axis_list[0].set_prop_cycle(cycler)
646+
647+
for chan in range(num_channels):
648+
label = channel_labels[chan] if channel_labels else None
649+
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
650+
651+
_plot_attrs_as_axvspan(
652+
norm_attr,
653+
x_values,
654+
plt_axis_list[0],
655+
x_values,
656+
cmap,
657+
cm_norm,
658+
alpha_overlay,
659+
)
660+
661+
plt_axis_list[0].legend(loc="best")
662+
663+
664+
def _visualize_colored_graph(
665+
# pyre-fixme[2]: Parameter must be annotated.
666+
num_channels,
667+
# pyre-fixme[2]: Parameter must be annotated.
668+
plt_axis_list,
669+
# pyre-fixme[2]: Parameter must be annotated.
670+
x_values,
671+
# pyre-fixme[2]: Parameter must be annotated.
672+
data,
673+
# pyre-fixme[2]: Parameter must be annotated.
674+
channel_labels,
675+
# pyre-fixme[2]: Parameter must be annotated.
676+
norm_attr,
677+
# pyre-fixme[2]: Parameter must be annotated.
678+
cmap,
679+
# pyre-fixme[2]: Parameter must be annotated.
680+
cm_norm,
681+
**kwargs: Any,
682+
) -> None:
683+
# helper method for visualize_timeseries_attr
684+
pyplot_kwargs = kwargs.get("pyplot_kwargs", {})
685+
for chan in range(num_channels):
686+
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
687+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
688+
689+
lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
690+
lc.set_array(norm_attr[chan, :])
691+
plt_axis_list[chan].add_collection(lc)
692+
plt_axis_list[chan].set_ylim(
693+
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
694+
)
695+
if channel_labels is not None:
696+
plt_axis_list[chan].set_ylabel(channel_labels[chan])
697+
698+
plt.subplots_adjust(hspace=0)
699+
700+
531701
def visualize_timeseries_attr(
532702
attr: npt.NDArray,
533703
data: npt.NDArray,
@@ -686,8 +856,8 @@ def visualize_timeseries_attr(
686856

687857
num_subplots = num_channels
688858
if (
689-
TimeseriesVisualizationMethod[method]
690-
== TimeseriesVisualizationMethod.overlay_combined
859+
TimeseriesVisualizationMethod[method].value
860+
== TimeseriesVisualizationMethod.overlay_combined.value
691861
):
692862
num_subplots = 1
693863
attr = np.sum(attr, axis=0) # Merge attributions across channels
@@ -700,17 +870,9 @@ def visualize_timeseries_attr(
700870
x_values = np.arange(timeseries_length)
701871

702872
# Create plot if figure, axis not provided
703-
if plt_fig_axis is not None:
704-
plt_fig, plt_axis = plt_fig_axis
705-
else:
706-
if use_pyplot:
707-
plt_fig, plt_axis = plt.subplots( # type: ignore
708-
figsize=fig_size, nrows=num_subplots, sharex=True
709-
)
710-
else:
711-
plt_fig = Figure(figsize=fig_size)
712-
plt_axis = plt_fig.subplots(nrows=num_subplots, sharex=True) # type: ignore
713-
# Figure.subplots returns Axes or array of Axes
873+
plt_fig, plt_axis = _create_default_plot(
874+
plt_fig_axis, use_pyplot, fig_size, nrows=num_subplots, sharex=True
875+
)
714876

715877
if not isinstance(plt_axis, ndarray):
716878
plt_axis_list = np.array([plt_axis])
@@ -720,91 +882,30 @@ def visualize_timeseries_attr(
720882
norm_attr = _normalize_attr(attr, sign, outlier_perc, reduction_axis=None)
721883

722884
# Set default colormap and bounds based on sign.
723-
if VisualizeSign[sign] == VisualizeSign.all:
724-
default_cmap: Union[str, LinearSegmentedColormap] = (
725-
LinearSegmentedColormap.from_list("RdWhGn", ["red", "white", "green"])
726-
)
727-
vmin, vmax = -1, 1
728-
elif VisualizeSign[sign] == VisualizeSign.positive:
729-
default_cmap = "Greens"
730-
vmin, vmax = 0, 1
731-
elif VisualizeSign[sign] == VisualizeSign.negative:
732-
default_cmap = "Reds"
733-
vmin, vmax = 0, 1
734-
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
735-
default_cmap = "Blues"
736-
vmin, vmax = 0, 1
737-
else:
738-
raise AssertionError("Visualize Sign type is not valid.")
885+
default_cmap, vmin, vmax = _initialize_cmap_and_vmin_vmax(sign)
739886
cmap = cmap if cmap is not None else default_cmap
740887
cmap = cm.get_cmap(cmap) # type: ignore
741888
cm_norm = colors.Normalize(vmin, vmax)
742889

743-
# pyre-fixme[53]: Captured variable `cm_norm` is not annotated.
744-
# pyre-fixme[2]: Parameter must be annotated.
745-
def _plot_attrs_as_axvspan(attr_vals, x_vals, ax) -> None:
746-
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
747-
half_col_width = (x_values[1] - x_values[0]) / 2.0
748-
for icol, col_center in enumerate(x_vals):
749-
left = col_center - half_col_width
750-
right = col_center + half_col_width
751-
ax.axvspan(
752-
xmin=left,
753-
xmax=right,
754-
# pyre-fixme[29]: `Union[None, Colormap, str]` is not a function.
755-
facecolor=(cmap(cm_norm(attr_vals[icol]))), # type: ignore
756-
edgecolor=None,
757-
alpha=alpha_overlay,
758-
)
759-
760-
if (
761-
TimeseriesVisualizationMethod[method]
762-
== TimeseriesVisualizationMethod.overlay_individual
763-
):
764-
for chan in range(num_channels):
765-
plt_axis_list[chan].plot(x_values, data[chan, :], **pyplot_kwargs)
766-
if channel_labels is not None:
767-
plt_axis_list[chan].set_ylabel(channel_labels[chan])
768-
769-
_plot_attrs_as_axvspan(norm_attr[chan], x_values, plt_axis_list[chan])
770-
771-
plt.subplots_adjust(hspace=0)
772-
773-
elif (
774-
TimeseriesVisualizationMethod[method]
775-
== TimeseriesVisualizationMethod.overlay_combined
776-
):
777-
# Dark colors are better in this case
778-
cycler = plt.cycler("color", matplotlib.colormaps["Dark2"]) # type: ignore
779-
plt_axis_list[0].set_prop_cycle(cycler)
780-
781-
for chan in range(num_channels):
782-
label = channel_labels[chan] if channel_labels else None
783-
plt_axis_list[0].plot(x_values, data[chan, :], label=label, **pyplot_kwargs)
784-
785-
_plot_attrs_as_axvspan(norm_attr, x_values, plt_axis_list[0])
786-
787-
plt_axis_list[0].legend(loc="best")
788-
789-
elif (
790-
TimeseriesVisualizationMethod[method]
791-
== TimeseriesVisualizationMethod.colored_graph
792-
):
793-
for chan in range(num_channels):
794-
points = np.array([x_values, data[chan, :]]).T.reshape(-1, 1, 2)
795-
segments = np.concatenate([points[:-1], points[1:]], axis=1)
796-
797-
lc = LineCollection(segments, cmap=cmap, norm=cm_norm, **pyplot_kwargs)
798-
lc.set_array(norm_attr[chan, :])
799-
plt_axis_list[chan].add_collection(lc)
800-
plt_axis_list[chan].set_ylim(
801-
1.2 * np.min(data[chan, :]), 1.2 * np.max(data[chan, :])
802-
)
803-
if channel_labels is not None:
804-
plt_axis_list[chan].set_ylabel(channel_labels[chan])
805-
806-
plt.subplots_adjust(hspace=0)
807-
890+
visualization_methods: Dict[str, Callable[..., Union[None, AxesImage]]] = {
891+
"overlay_individual": _visualize_overlay_individual,
892+
"overlay_combined": _visualize_overlay_combined,
893+
"colored_graph": _visualize_colored_graph,
894+
}
895+
kwargs = {
896+
"num_channels": num_channels,
897+
"plt_axis_list": plt_axis_list,
898+
"x_values": x_values,
899+
"data": data,
900+
"channel_labels": channel_labels,
901+
"norm_attr": norm_attr,
902+
"cmap": cmap,
903+
"cm_norm": cm_norm,
904+
"alpha_overlay": alpha_overlay,
905+
"pyplot_kwargs": pyplot_kwargs,
906+
}
907+
if method in visualization_methods:
908+
visualization_methods[method](**kwargs)
808909
else:
809910
raise AssertionError("Invalid visualization method: {}".format(method))
810911

0 commit comments

Comments
 (0)