From 1dcf6ca8f64762a7ab1327cc3c46a0a809acf886 Mon Sep 17 00:00:00 2001 From: George Katsikas <giorgakis.katsikas@gmail.com> Date: Fri, 24 Jan 2025 15:33:23 +0100 Subject: [PATCH] =?UTF-8?q?refactor(graphs):=20=E2=99=BB=EF=B8=8F=20differ?= =?UTF-8?q?entiate=20aggragated=20value=20key?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Renames value_key to agg_value_key for the plot kinds where it is aggregated --- scipost_django/graphs/forms.py | 3 +- scipost_django/graphs/graphs/plotkind.py | 44 +++++++++++++----------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/scipost_django/graphs/forms.py b/scipost_django/graphs/forms.py index 4e9f87090..638f0a7ae 100644 --- a/scipost_django/graphs/forms.py +++ b/scipost_django/graphs/forms.py @@ -137,7 +137,8 @@ class PlotOptionsForm(InitialCoalescedForm): """ FIELD_ADMISSIBLE_TYPES: dict[str, list[str]] = { - "value_key": ["int", "float"], + "value_key": ["int", "float", "date"], + "agg_value_key": ["int", "float"], "timeline_key": ["date", "datetime"], "group_key": ["str", "int", "country"], "country_key": ["country"], diff --git a/scipost_django/graphs/graphs/plotkind.py b/scipost_django/graphs/graphs/plotkind.py index 794b019d9..e4d6f6bbd 100644 --- a/scipost_django/graphs/graphs/plotkind.py +++ b/scipost_django/graphs/graphs/plotkind.py @@ -231,8 +231,8 @@ class MapPlot(PlotKind): from graphs.graphs import BASE_WORLD, OKLCH agg_func = self.options.get("agg_func", "count") - value_key_display = self.plotter.get_model_field_display( - self.options.get("value_key") + agg_value_key_display = self.plotter.get_model_field_display( + self.options.get("agg_value_key") ) country_key_display = ( self.plotter.get_model_field_display(self.options.get("country_key")) @@ -241,12 +241,12 @@ class MapPlot(PlotKind): if agg_func == "count": plot_title = "{model} per {country}" else: - plot_title = "{agg_func} of {model}' {value_key_display} per {country}" + plot_title = "{agg_func} of {model}' {agg_value_key_display} per {country}" plot_title = plot_title.format( model=self.plotter.model._meta.verbose_name_plural, agg_func=agg_func, - value_key_display=value_key_display, + agg_value_key_display=agg_value_key_display, country=country_key_display, ).capitalize() color_plot_title, _ = plot_title.split(" per ") @@ -300,7 +300,7 @@ class MapPlot(PlotKind): """ qs = self.plotter.get_queryset() - value_key = self.options.get("value_key", "id") or "id" + agg_value_key = self.options.get("agg_value_key", "id") or "id" country_key = self.options.get("country_key") if country_key is None: @@ -310,9 +310,9 @@ class MapPlot(PlotKind): case "count": agg_func = Count("id") case "sum": - agg_func = Sum(value_key) + agg_func = Sum(agg_value_key) case "avg": - agg_func = Avg(value_key) + agg_func = Avg(agg_value_key) case _: raise ValueError("Invalid aggregation function") @@ -341,7 +341,9 @@ class MapPlot(PlotKind): required=False, initial="count", ) - value_key = forms.ChoiceField(label="Value key", required=False, choices=[]) + agg_value_key = forms.ChoiceField( + label="Agg Value key", required=False, choices=[] + ) country_key = forms.ChoiceField(label="Country key", required=False, choices=[]) @classmethod @@ -349,7 +351,7 @@ class MapPlot(PlotKind): return Layout( Div(Field("country_key"), css_class="col-12"), Div(Field("agg_func"), css_class="col-6"), - Div(Field("value_key"), css_class="col-6"), + Div(Field("agg_value_key"), css_class="col-6"), ) @@ -374,17 +376,17 @@ class BarPlot(PlotKind): ): ax.set(**{f"{group_label_axis}label": group_key_label.capitalize()}) - if value_key_name := self.plotter.get_model_field_display( - self.options.get("value_key") + if agg_value_key_name := self.plotter.get_model_field_display( + self.options.get("agg_value_key") ): - value_key_label = f"{agg_func} of {value_key_name}" + agg_value_key_label = f"{agg_func} of {agg_value_key_name}" if agg_func == "count": # Simplify label and set locator to integer - value_key_label = "Count" + agg_value_key_label = "Count" axis = getattr(ax, f"{value_label_axis}axis") axis.get_major_locator().set_params(integer=True) - ax.set(**{f"{value_label_axis}label": value_key_label.capitalize()}) + ax.set(**{f"{value_label_axis}label": agg_value_key_label.capitalize()}) try: groups, vals = self.get_data() @@ -409,7 +411,7 @@ class BarPlot(PlotKind): return fig def get_data(self): - value_key = self.options.get("value_key", "id") or "id" + agg_value_key = self.options.get("agg_value_key", "id") or "id" group_key = self.options.get("group_key") direction = self.options.get("direction", "vertical") or "vertical" @@ -420,11 +422,11 @@ class BarPlot(PlotKind): match self.options.get("agg_func", "count"): case "count": - agg_func = Count("id") + agg_func = Count(group_key) case "sum": - agg_func = Sum(value_key) + agg_func = Sum(agg_value_key) case "avg": - agg_func = Avg(value_key) + agg_func = Avg(agg_value_key) case _: raise ValueError("Invalid aggregation function") @@ -475,8 +477,8 @@ class BarPlot(PlotKind): group_key = forms.ChoiceField( label="Group by key", required=False, initial="id", choices=[] ) - value_key = forms.ChoiceField( - label="Value key", required=False, initial="id", choices=[] + agg_value_key = forms.ChoiceField( + label="Agg Value key", required=False, initial="id", choices=[] ) agg_func = forms.ChoiceField( label="Aggregation function", @@ -514,7 +516,7 @@ class BarPlot(PlotKind): ), Div(Field("group_key"), css_class="col-12"), Div(Field("agg_func"), css_class="col-6"), - Div(Field("value_key"), css_class="col-6"), + Div(Field("agg_value_key"), css_class="col-6"), Div( Div( Div(Field("order_by"), css_class="col-6"), -- GitLab