diff --git a/scipost_django/graphs/forms.py b/scipost_django/graphs/forms.py index 4b6950b1e85c14e036d23122b6cab0ec671aed77..1a2ceb9e4a786c709e02a96dbcbdecf5a1e0a31c 100644 --- a/scipost_django/graphs/forms.py +++ b/scipost_django/graphs/forms.py @@ -8,7 +8,8 @@ from graphs.graphs.plotter import ModelFieldPlotter from .graphs import ALL_PLOTTERS, ALL_PLOT_KINDS, AVAILABLE_MPL_THEMES -import matplotlib.pyplot as plt +from crispy_forms.helper import FormHelper, Layout +from crispy_forms.layout import Div, Field class ModelFieldPlotterSelectForm(forms.Form): @@ -92,6 +93,51 @@ class PlotOptionsForm(forms.Form): self.fields.update(self.plot_kind_select_form.fields) self.fields.update(self.generic_plot_options_form.fields) + self.helper = FormHelper() + self.helper.layout = Layout() + + def get_layout_field_names(layout): + """Recurse through a layout to get all field names.""" + field_names = [] + for field in layout: + if isinstance(field, str): + field_names.append(field) + else: + field_names.extend(get_layout_field_names(field.fields)) + return field_names + + # Iterate over all forms and construct the form layout + # either by extending the layout with the preferred layout from the object class + # or by creating a row with all fields that are not already in the layout + for form, object_class in { + self.model_field_select_form: self.model_field_select_form.plotter.__class__, + self.plot_kind_select_form: self.plot_kind_select_form.kind_class, + self.generic_plot_options_form: None, + }.items(): + + layout = Layout() + if object_class not in (None, None.__class__): + principal_field_name = next(iter(form.fields.keys())) + layout.append(Div(Field(principal_field_name), css_class="col-12")) + + row_constructor = getattr( + object_class, "get_plot_options_form_layout_row_content" + ) + if row_constructor: + layout.extend(row_constructor()) + + layout.extend( + [ + Div(Field(field_name), css_class="col-12") + for field_name in form.fields.keys() + if field_name not in get_layout_field_names(layout) + ], + ) + + self.helper.layout.append(layout) + + self.helper.all().wrap(Div, css_class="row") + def clean(self): cleaned_data = super().clean() diff --git a/scipost_django/graphs/graphs/plotkind.py b/scipost_django/graphs/graphs/plotkind.py index babf740f2e376046368d5e8059d2c48c8acb42f8..3503c7bd36e5cdf077f958124f97b7b0eca506c6 100644 --- a/scipost_django/graphs/graphs/plotkind.py +++ b/scipost_django/graphs/graphs/plotkind.py @@ -7,6 +7,7 @@ from matplotlib.figure import Figure import pandas as pd from .options import BaseOptions +from crispy_forms.layout import Layout, Div, Field from typing import TYPE_CHECKING, Any @@ -69,10 +70,23 @@ class PlotKind: ax = fig.add_subplot(111) ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}") - x, y = self.get_data() - ax.plot(x, y) + try: + x, y = self.get_data() + ax.plot(x, y) + except ValueError as e: + self.display_plotting_error(ax) + return fig + def display_plotting_error(self, ax): + ax.text(0.5, 0.5, f"No data to plot", ha="center", va="center") + ax.grid(False) + ax.axis("off") + + @classmethod + def get_plot_options_form_layout_row_content(cls): + return Div() + class TimelinePlot(PlotKind): name = "timeline" @@ -124,6 +138,38 @@ class TimelinePlot(PlotKind): widget=forms.DateTimeInput(attrs={"type": "date"}), ) + @classmethod + def get_plot_options_form_layout_row_content(cls): + layout = Layout( + Div(Field("y_key"), css_class="col-12"), + Div(Field("x_lim_min"), css_class="col-6"), + Div(Field("x_lim_max"), css_class="col-6"), + ) + + # Prefix every field in the layout with the prefix + def prefix_field(field): + """ + Recursively prefix the fields in a layout. + Return type is irrelevant, as it modifies the argument directly. + """ + contained_fields = getattr(field, "fields", None) + if contained_fields is None: + return + + # If the crispy field is a Field type with a single string identifier, prefix it + if ( + isinstance(field, Field) + and len(contained_fields) == 1 + and isinstance(field_key := contained_fields[0], str) + ): + field.fields = [cls.Options.prefix + field_key] + else: + return [prefix_field(f) for f in contained_fields] + + prefix_field(layout) + + return layout + class MapPlot(PlotKind): name = "map" diff --git a/scipost_django/graphs/graphs/plotter.py b/scipost_django/graphs/graphs/plotter.py index d7d37e110e01cb62503a37a59b817d1fede670b7..aeed2979a74903fdc7e98ef2902de1594692f432 100644 --- a/scipost_django/graphs/graphs/plotter.py +++ b/scipost_django/graphs/graphs/plotter.py @@ -19,6 +19,8 @@ from submissions.models import Report, Submission from .options import BaseOptions +from crispy_forms.layout import Div + OptionDict = dict[str, Any] if TYPE_CHECKING: @@ -80,10 +82,15 @@ class ModelFieldPlotter(ABC): ) fig = kind.plot() - fig.suptitle(options.get("title", None)) + if title := options.get("title", None): + fig.axes[0].set_title(title) return fig + @classmethod + def get_plot_options_form_layout_row_content(cls): + return Div() + class PublicationPlotter(ModelFieldPlotter): model = Publication