diff --git a/scipost_django/graphs/forms.py b/scipost_django/graphs/forms.py index a185911f2f96ab588da8419c9a79f24b226eb20a..4b6950b1e85c14e036d23122b6cab0ec671aed77 100644 --- a/scipost_django/graphs/forms.py +++ b/scipost_django/graphs/forms.py @@ -50,17 +50,11 @@ class PlotKindSelectForm(forms.Form): super().__init__(*args, **kwargs) # If a plot kind is already selected, populate the form with its options - if plot_kind := PlotKind.from_name(self.data.get("plot_kind", None), self.data): - self.fields.update(plot_kind.Options.get_option_fields()) + plot_kind_class_name = self.data.get("plot_kind", None) + if plot_kind_class := PlotKind.class_from_name(plot_kind_class_name): + self.fields.update(plot_kind_class.Options.get_option_fields()) - self.kind = plot_kind - - def clean(self): - cleaned_data = super().clean() - - # Recreate plot kind with cleaned data - self.kind = PlotKind.from_name(self.data.get("plot_kind", None), cleaned_data) - return cleaned_data + self.kind_class = plot_kind_class class GenericPlotOptionsForm(forms.Form): diff --git a/scipost_django/graphs/graphs/plotkind.py b/scipost_django/graphs/graphs/plotkind.py index 02b3817ef0247fe73d72ae7390280a82fda0f515..babf740f2e376046368d5e8059d2c48c8acb42f8 100644 --- a/scipost_django/graphs/graphs/plotkind.py +++ b/scipost_django/graphs/graphs/plotkind.py @@ -25,15 +25,18 @@ class PlotKind: prefix = "plot_kind_" pass - def __init__(self, options: "OptionDict" = {}): + def __init__(self, plotter: "ModelFieldPlotter", options: "OptionDict" = {}): + self.plotter = plotter self.options = self.Options.parse_prefixed_options(options) @classmethod - def from_name(cls, name: str, *args, **kwargs): + def class_from_name(cls, name: str): from graphs.graphs import ALL_PLOT_KINDS if cls_name := ALL_PLOT_KINDS.get(name, None): - return cls_name(*args, **kwargs) + return cls_name + + return PlotKind @classmethod def get_name(cls) -> str: @@ -49,24 +52,24 @@ class PlotKind: """ return Figure(**kwargs) - def get_data(self, plotter: "ModelFieldPlotter") -> tuple[list[int], list[Any]]: + def get_data(self) -> tuple[list[int], list[Any]]: """ Obtain the values to plot from the queryset. """ - qs = plotter.get_queryset() + qs = self.plotter.get_queryset() y = qs.values_list("id", flat=True) x = list(range(len(y))) return x, y - def plot(self, plotter: "ModelFieldPlotter"): + def plot(self): """ Plot the data on a the figure. """ fig = self.get_figure() ax = fig.add_subplot(111) - ax.set_title(f"{self.get_name()} plot of {plotter.model.__name__}") + ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}") - x, y = self.get_data(plotter) + x, y = self.get_data() ax.plot(x, y) return fig @@ -74,36 +77,36 @@ class PlotKind: class TimelinePlot(PlotKind): name = "timeline" - def plot(self, plotter: "ModelFieldPlotter"): - fig = super().plot(plotter) + def plot(self): + fig = super().plot() ax = fig.get_axes()[0] - ax.set_xlabel(plotter.date_key) + ax.set_xlabel(self.plotter.date_key) ax.set_ylabel(self.options.get("y_key", "id")) return fig - def get_data(self, plotter: "ModelFieldPlotter", **kwargs): + def get_data(self): y_key = self.options.get("y_key", "id") or "id" # Filter the queryset to only include entries with a date and a y value query_filters = Q( **{ - plotter.date_key + "__isnull": False, + self.plotter.date_key + "__isnull": False, y_key + "__isnull": False, } ) # Filter the queryset according to the date limits if they are set if x_lim_min := self.options.get("x_lim_min", None): - query_filters &= Q(**{plotter.date_key + "__gte": x_lim_min}) + query_filters &= Q(**{self.plotter.date_key + "__gte": x_lim_min}) if x_lim_max := self.options.get("x_lim_max", None): - query_filters &= Q(**{plotter.date_key + "__lte": x_lim_max}) + query_filters &= Q(**{self.plotter.date_key + "__lte": x_lim_max}) - qs = plotter.get_queryset() + qs = self.plotter.get_queryset() qs = qs.filter(query_filters) - qs = qs.order_by(plotter.date_key) + qs = qs.order_by(self.plotter.date_key) - x, y = zip(*qs.values_list(plotter.date_key, y_key)) + x, y = zip(*qs.values_list(self.plotter.date_key, y_key)) return x, y @@ -158,7 +161,7 @@ class MapPlot(PlotKind): cax0.tick_params(axis="x", length=1.5, direction="out", which="minor") cax0.grid(False) - def plot(self, plotter: "ModelFieldPlotter"): + def plot(self): from graphs.graphs import BASE_WORLD, OKLCH from matplotlib.colors import LinearSegmentedColormap, LogNorm @@ -166,7 +169,7 @@ class MapPlot(PlotKind): self.draw_colorbar(fig) ax, cax, _ = fig.get_axes() - countries, count = self.get_data(plotter) + countries, count = self.get_data() df_counts = pd.DataFrame({"ISO_A2_EH": countries, "count": count}) vmax = df_counts["count"].max() BASE_WORLD.merge(df_counts, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot( @@ -188,16 +191,19 @@ class MapPlot(PlotKind): return fig - def get_data(self, plotter: "ModelFieldPlotter", **kwargs): - qs = plotter.get_queryset() + def get_data(self): + """ + Return the a tuple of lists of countries and their counts. + """ + qs = self.plotter.get_queryset() count_key = self.options.get("count_key", "id") group_by_country_count = ( - qs.filter(Q(**{plotter.country_key + "__isnull": False})) - .values(plotter.country_key) + qs.filter(Q(**{self.plotter.country_key + "__isnull": False})) + .values(self.plotter.country_key) .annotate(count=Count(count_key)) ) countries, count = zip( - *group_by_country_count.values_list(plotter.country_key, "count") + *group_by_country_count.values_list(self.plotter.country_key, "count") ) return countries, count diff --git a/scipost_django/graphs/graphs/plotter.py b/scipost_django/graphs/graphs/plotter.py index b8f992b1154ba51197587ab808ac5969a665e5de..d7d37e110e01cb62503a37a59b817d1fede670b7 100644 --- a/scipost_django/graphs/graphs/plotter.py +++ b/scipost_django/graphs/graphs/plotter.py @@ -79,7 +79,7 @@ class ModelFieldPlotter(ABC): ] ) - fig = kind.plot(plotter=self) + fig = kind.plot() fig.suptitle(options.get("title", None)) return fig diff --git a/scipost_django/graphs/views.py b/scipost_django/graphs/views.py index dd50620b67ab13e1c5961f1798fb6de3c06e7b4c..aa401b6e1060b83ff82a188f190c23e85e896690 100644 --- a/scipost_django/graphs/views.py +++ b/scipost_django/graphs/views.py @@ -50,7 +50,10 @@ class PlotView(View): cleaned_data = form.clean() self.plotter = form.model_field_select_form.plotter - self.kind = form.plot_kind_select_form.kind + self.kind = form.plot_kind_select_form.kind_class( + options=form.plot_kind_select_form.cleaned_data, + plotter=self.plotter, + ) self.plot_options = { "plot_kind": {},