SciPost Code Repository

Skip to content
Snippets Groups Projects
Commit ce69a4a5 authored by George Katsikas's avatar George Katsikas :goat:
Browse files

dynamically construct crispy layout in graphs plot

parent 21a64456
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,8 @@ from graphs.graphs.plotter import ModelFieldPlotter ...@@ -8,7 +8,8 @@ from graphs.graphs.plotter import ModelFieldPlotter
from .graphs import ALL_PLOTTERS, ALL_PLOT_KINDS, AVAILABLE_MPL_THEMES from .graphs import ALL_PLOTTERS, ALL_PLOT_KINDS, AVAILABLE_MPL_THEMES
import matplotlib.pyplot as plt from crispy_forms.helper import FormHelper, Layout
from crispy_forms.layout import Div, Field
class ModelFieldPlotterSelectForm(forms.Form): class ModelFieldPlotterSelectForm(forms.Form):
...@@ -92,6 +93,51 @@ class PlotOptionsForm(forms.Form): ...@@ -92,6 +93,51 @@ class PlotOptionsForm(forms.Form):
self.fields.update(self.plot_kind_select_form.fields) self.fields.update(self.plot_kind_select_form.fields)
self.fields.update(self.generic_plot_options_form.fields) self.fields.update(self.generic_plot_options_form.fields)
self.helper = FormHelper()
self.helper.layout = Layout()
def get_layout_field_names(layout):
"""Recurse through a layout to get all field names."""
field_names = []
for field in layout:
if isinstance(field, str):
field_names.append(field)
else:
field_names.extend(get_layout_field_names(field.fields))
return field_names
# Iterate over all forms and construct the form layout
# either by extending the layout with the preferred layout from the object class
# or by creating a row with all fields that are not already in the layout
for form, object_class in {
self.model_field_select_form: self.model_field_select_form.plotter.__class__,
self.plot_kind_select_form: self.plot_kind_select_form.kind_class,
self.generic_plot_options_form: None,
}.items():
layout = Layout()
if object_class not in (None, None.__class__):
principal_field_name = next(iter(form.fields.keys()))
layout.append(Div(Field(principal_field_name), css_class="col-12"))
row_constructor = getattr(
object_class, "get_plot_options_form_layout_row_content"
)
if row_constructor:
layout.extend(row_constructor())
layout.extend(
[
Div(Field(field_name), css_class="col-12")
for field_name in form.fields.keys()
if field_name not in get_layout_field_names(layout)
],
)
self.helper.layout.append(layout)
self.helper.all().wrap(Div, css_class="row")
def clean(self): def clean(self):
cleaned_data = super().clean() cleaned_data = super().clean()
......
...@@ -7,6 +7,7 @@ from matplotlib.figure import Figure ...@@ -7,6 +7,7 @@ from matplotlib.figure import Figure
import pandas as pd import pandas as pd
from .options import BaseOptions from .options import BaseOptions
from crispy_forms.layout import Layout, Div, Field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
...@@ -69,10 +70,23 @@ class PlotKind: ...@@ -69,10 +70,23 @@ class PlotKind:
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}") ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}")
x, y = self.get_data() try:
ax.plot(x, y) x, y = self.get_data()
ax.plot(x, y)
except ValueError as e:
self.display_plotting_error(ax)
return fig return fig
def display_plotting_error(self, ax):
ax.text(0.5, 0.5, f"No data to plot", ha="center", va="center")
ax.grid(False)
ax.axis("off")
@classmethod
def get_plot_options_form_layout_row_content(cls):
return Div()
class TimelinePlot(PlotKind): class TimelinePlot(PlotKind):
name = "timeline" name = "timeline"
...@@ -124,6 +138,38 @@ class TimelinePlot(PlotKind): ...@@ -124,6 +138,38 @@ class TimelinePlot(PlotKind):
widget=forms.DateTimeInput(attrs={"type": "date"}), widget=forms.DateTimeInput(attrs={"type": "date"}),
) )
@classmethod
def get_plot_options_form_layout_row_content(cls):
layout = Layout(
Div(Field("y_key"), css_class="col-12"),
Div(Field("x_lim_min"), css_class="col-6"),
Div(Field("x_lim_max"), css_class="col-6"),
)
# Prefix every field in the layout with the prefix
def prefix_field(field):
"""
Recursively prefix the fields in a layout.
Return type is irrelevant, as it modifies the argument directly.
"""
contained_fields = getattr(field, "fields", None)
if contained_fields is None:
return
# If the crispy field is a Field type with a single string identifier, prefix it
if (
isinstance(field, Field)
and len(contained_fields) == 1
and isinstance(field_key := contained_fields[0], str)
):
field.fields = [cls.Options.prefix + field_key]
else:
return [prefix_field(f) for f in contained_fields]
prefix_field(layout)
return layout
class MapPlot(PlotKind): class MapPlot(PlotKind):
name = "map" name = "map"
......
...@@ -19,6 +19,8 @@ from submissions.models import Report, Submission ...@@ -19,6 +19,8 @@ from submissions.models import Report, Submission
from .options import BaseOptions from .options import BaseOptions
from crispy_forms.layout import Div
OptionDict = dict[str, Any] OptionDict = dict[str, Any]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -80,10 +82,15 @@ class ModelFieldPlotter(ABC): ...@@ -80,10 +82,15 @@ class ModelFieldPlotter(ABC):
) )
fig = kind.plot() fig = kind.plot()
fig.suptitle(options.get("title", None)) if title := options.get("title", None):
fig.axes[0].set_title(title)
return fig return fig
@classmethod
def get_plot_options_form_layout_row_content(cls):
return Div()
class PublicationPlotter(ModelFieldPlotter): class PublicationPlotter(ModelFieldPlotter):
model = Publication model = Publication
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment