SciPost Code Repository

Skip to content
Snippets Groups Projects
Commit be6c1f82 authored by George Katsikas's avatar George Katsikas :goat:
Browse files

refactor(graphs): :recycle:️ extract field prefix to form

parent 377ebbad
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,7 @@ from graphs.graphs.plotter import ModelFieldPlotter ...@@ -9,7 +9,7 @@ from graphs.graphs.plotter import ModelFieldPlotter
from .graphs import ALL_PLOTTERS, ALL_PLOT_KINDS, AVAILABLE_MPL_THEMES from .graphs import ALL_PLOTTERS, ALL_PLOT_KINDS, AVAILABLE_MPL_THEMES
from crispy_forms.helper import FormHelper, Layout from crispy_forms.helper import FormHelper, Layout
from crispy_forms.layout import Div, Field from crispy_forms.layout import LayoutObject, Div, Field
class ModelFieldPlotterSelectForm(forms.Form): class ModelFieldPlotterSelectForm(forms.Form):
...@@ -98,7 +98,8 @@ class PlotOptionsForm(forms.Form): ...@@ -98,7 +98,8 @@ class PlotOptionsForm(forms.Form):
def get_layout_field_names(layout): def get_layout_field_names(layout):
"""Recurse through a layout to get all field names.""" """Recurse through a layout to get all field names."""
field_names = [] field_names: list[str] = []
field: LayoutObject | str
for field in layout: for field in layout:
if isinstance(field, str): if isinstance(field, str):
field_names.append(field) field_names.append(field)
...@@ -106,6 +107,25 @@ class PlotOptionsForm(forms.Form): ...@@ -106,6 +107,25 @@ class PlotOptionsForm(forms.Form):
field_names.extend(get_layout_field_names(field.fields)) field_names.extend(get_layout_field_names(field.fields))
return field_names return field_names
def prefix_layout_fields(prefix: str, field: LayoutObject):
"""
Recursively prefix the fields in a layout.
Return type is irrelevant, as it modifies the argument directly.
"""
if (contained_fields := getattr(field, "fields", None)) is None:
return
# If the crispy field is a Field type with a single string identifier, prefix it
if (
isinstance(field, Field)
and len(contained_fields) == 1
and isinstance(field_key := contained_fields[0], str)
):
field.fields = [prefix + field_key]
else:
[prefix_layout_fields(prefix, f) for f in contained_fields]
# Iterate over all forms and construct the form layout # Iterate over all forms and construct the form layout
# either by extending the layout with the preferred layout from the object class # either by extending the layout with the preferred layout from the object class
# or by creating a row with all fields that are not already in the layout # or by creating a row with all fields that are not already in the layout
...@@ -121,10 +141,18 @@ class PlotOptionsForm(forms.Form): ...@@ -121,10 +141,18 @@ class PlotOptionsForm(forms.Form):
layout.append(Div(Field(principal_field_name), css_class="col-12")) layout.append(Div(Field(principal_field_name), css_class="col-12"))
row_constructor = getattr( row_constructor = getattr(
object_class, "get_plot_options_form_layout_row_content" object_class, "get_plot_options_form_layout_row_content", None
) )
if row_constructor: if row_constructor:
layout.extend(row_constructor()) try:
object_class_prefix = object_class.Options.prefix or ""
except AttributeError:
object_class_prefix = ""
fields = row_constructor()
# In-place prefixing of the layout-field names
prefix_layout_fields(object_class_prefix, fields)
layout.extend(fields)
layout.extend( layout.extend(
[ [
......
...@@ -7,7 +7,7 @@ from matplotlib.figure import Figure ...@@ -7,7 +7,7 @@ from matplotlib.figure import Figure
import pandas as pd import pandas as pd
from .options import BaseOptions from .options import BaseOptions
from crispy_forms.layout import Layout, Div, Field from crispy_forms.layout import LayoutObject, Layout, Div, Field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
...@@ -84,7 +84,7 @@ class PlotKind: ...@@ -84,7 +84,7 @@ class PlotKind:
ax.axis("off") ax.axis("off")
@classmethod @classmethod
def get_plot_options_form_layout_row_content(cls): def get_plot_options_form_layout_row_content(cls) -> LayoutObject:
return Div() return Div()
...@@ -140,36 +140,12 @@ class TimelinePlot(PlotKind): ...@@ -140,36 +140,12 @@ class TimelinePlot(PlotKind):
@classmethod @classmethod
def get_plot_options_form_layout_row_content(cls): def get_plot_options_form_layout_row_content(cls):
layout = Layout( return Layout(
Div(Field("y_key"), css_class="col-12"), Div(Field("y_key"), css_class="col-12"),
Div(Field("x_lim_min"), css_class="col-6"), Div(Field("x_lim_min"), css_class="col-6"),
Div(Field("x_lim_max"), css_class="col-6"), Div(Field("x_lim_max"), css_class="col-6"),
) )
# Prefix every field in the layout with the prefix
def prefix_field(field):
"""
Recursively prefix the fields in a layout.
Return type is irrelevant, as it modifies the argument directly.
"""
contained_fields = getattr(field, "fields", None)
if contained_fields is None:
return
# If the crispy field is a Field type with a single string identifier, prefix it
if (
isinstance(field, Field)
and len(contained_fields) == 1
and isinstance(field_key := contained_fields[0], str)
):
field.fields = [cls.Options.prefix + field_key]
else:
return [prefix_field(f) for f in contained_fields]
prefix_field(layout)
return layout
class MapPlot(PlotKind): class MapPlot(PlotKind):
name = "map" name = "map"
...@@ -303,31 +279,7 @@ class MapPlot(PlotKind): ...@@ -303,31 +279,7 @@ class MapPlot(PlotKind):
@classmethod @classmethod
def get_plot_options_form_layout_row_content(cls): def get_plot_options_form_layout_row_content(cls):
layout = Layout( return Layout(
Div(Field("agg_func"), css_class="col-6"), Div(Field("agg_func"), css_class="col-6"),
Div(Field("agg_key"), css_class="col-6"), Div(Field("agg_key"), css_class="col-6"),
) )
# Prefix every field in the layout with the prefix
def prefix_field(field):
"""
Recursively prefix the fields in a layout.
Return type is irrelevant, as it modifies the argument directly.
"""
contained_fields = getattr(field, "fields", None)
if contained_fields is None:
return
# If the crispy field is a Field type with a single string identifier, prefix it
if (
isinstance(field, Field)
and len(contained_fields) == 1
and isinstance(field_key := contained_fields[0], str)
):
field.fields = [cls.Options.prefix + field_key]
else:
return [prefix_field(f) for f in contained_fields]
prefix_field(layout)
return layout
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment