From 1dcf6ca8f64762a7ab1327cc3c46a0a809acf886 Mon Sep 17 00:00:00 2001
From: George Katsikas <giorgakis.katsikas@gmail.com>
Date: Fri, 24 Jan 2025 15:33:23 +0100
Subject: [PATCH] =?UTF-8?q?refactor(graphs):=20=E2=99=BB=EF=B8=8F=20differ?=
 =?UTF-8?q?entiate=20aggragated=20value=20key?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Renames value_key to agg_value_key for the plot kinds where it is aggregated
---
 scipost_django/graphs/forms.py           |  3 +-
 scipost_django/graphs/graphs/plotkind.py | 44 +++++++++++++-----------
 2 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/scipost_django/graphs/forms.py b/scipost_django/graphs/forms.py
index 4e9f87090..638f0a7ae 100644
--- a/scipost_django/graphs/forms.py
+++ b/scipost_django/graphs/forms.py
@@ -137,7 +137,8 @@ class PlotOptionsForm(InitialCoalescedForm):
     """
 
     FIELD_ADMISSIBLE_TYPES: dict[str, list[str]] = {
-        "value_key": ["int", "float"],
+        "value_key": ["int", "float", "date"],
+        "agg_value_key": ["int", "float"],
         "timeline_key": ["date", "datetime"],
         "group_key": ["str", "int", "country"],
         "country_key": ["country"],
diff --git a/scipost_django/graphs/graphs/plotkind.py b/scipost_django/graphs/graphs/plotkind.py
index 794b019d9..e4d6f6bbd 100644
--- a/scipost_django/graphs/graphs/plotkind.py
+++ b/scipost_django/graphs/graphs/plotkind.py
@@ -231,8 +231,8 @@ class MapPlot(PlotKind):
         from graphs.graphs import BASE_WORLD, OKLCH
 
         agg_func = self.options.get("agg_func", "count")
-        value_key_display = self.plotter.get_model_field_display(
-            self.options.get("value_key")
+        agg_value_key_display = self.plotter.get_model_field_display(
+            self.options.get("agg_value_key")
         )
         country_key_display = (
             self.plotter.get_model_field_display(self.options.get("country_key"))
@@ -241,12 +241,12 @@ class MapPlot(PlotKind):
         if agg_func == "count":
             plot_title = "{model} per {country}"
         else:
-            plot_title = "{agg_func} of {model}' {value_key_display} per {country}"
+            plot_title = "{agg_func} of {model}' {agg_value_key_display} per {country}"
 
         plot_title = plot_title.format(
             model=self.plotter.model._meta.verbose_name_plural,
             agg_func=agg_func,
-            value_key_display=value_key_display,
+            agg_value_key_display=agg_value_key_display,
             country=country_key_display,
         ).capitalize()
         color_plot_title, _ = plot_title.split(" per ")
@@ -300,7 +300,7 @@ class MapPlot(PlotKind):
         """
         qs = self.plotter.get_queryset()
 
-        value_key = self.options.get("value_key", "id") or "id"
+        agg_value_key = self.options.get("agg_value_key", "id") or "id"
         country_key = self.options.get("country_key")
 
         if country_key is None:
@@ -310,9 +310,9 @@ class MapPlot(PlotKind):
             case "count":
                 agg_func = Count("id")
             case "sum":
-                agg_func = Sum(value_key)
+                agg_func = Sum(agg_value_key)
             case "avg":
-                agg_func = Avg(value_key)
+                agg_func = Avg(agg_value_key)
             case _:
                 raise ValueError("Invalid aggregation function")
 
@@ -341,7 +341,9 @@ class MapPlot(PlotKind):
             required=False,
             initial="count",
         )
-        value_key = forms.ChoiceField(label="Value key", required=False, choices=[])
+        agg_value_key = forms.ChoiceField(
+            label="Agg Value key", required=False, choices=[]
+        )
         country_key = forms.ChoiceField(label="Country key", required=False, choices=[])
 
     @classmethod
@@ -349,7 +351,7 @@ class MapPlot(PlotKind):
         return Layout(
             Div(Field("country_key"), css_class="col-12"),
             Div(Field("agg_func"), css_class="col-6"),
-            Div(Field("value_key"), css_class="col-6"),
+            Div(Field("agg_value_key"), css_class="col-6"),
         )
 
 
@@ -374,17 +376,17 @@ class BarPlot(PlotKind):
         ):
             ax.set(**{f"{group_label_axis}label": group_key_label.capitalize()})
 
-        if value_key_name := self.plotter.get_model_field_display(
-            self.options.get("value_key")
+        if agg_value_key_name := self.plotter.get_model_field_display(
+            self.options.get("agg_value_key")
         ):
-            value_key_label = f"{agg_func} of {value_key_name}"
+            agg_value_key_label = f"{agg_func} of {agg_value_key_name}"
             if agg_func == "count":
                 # Simplify label and set locator to integer
-                value_key_label = "Count"
+                agg_value_key_label = "Count"
                 axis = getattr(ax, f"{value_label_axis}axis")
                 axis.get_major_locator().set_params(integer=True)
 
-            ax.set(**{f"{value_label_axis}label": value_key_label.capitalize()})
+            ax.set(**{f"{value_label_axis}label": agg_value_key_label.capitalize()})
 
         try:
             groups, vals = self.get_data()
@@ -409,7 +411,7 @@ class BarPlot(PlotKind):
         return fig
 
     def get_data(self):
-        value_key = self.options.get("value_key", "id") or "id"
+        agg_value_key = self.options.get("agg_value_key", "id") or "id"
         group_key = self.options.get("group_key")
         direction = self.options.get("direction", "vertical") or "vertical"
 
@@ -420,11 +422,11 @@ class BarPlot(PlotKind):
 
         match self.options.get("agg_func", "count"):
             case "count":
-                agg_func = Count("id")
+                agg_func = Count(group_key)
             case "sum":
-                agg_func = Sum(value_key)
+                agg_func = Sum(agg_value_key)
             case "avg":
-                agg_func = Avg(value_key)
+                agg_func = Avg(agg_value_key)
             case _:
                 raise ValueError("Invalid aggregation function")
 
@@ -475,8 +477,8 @@ class BarPlot(PlotKind):
         group_key = forms.ChoiceField(
             label="Group by key", required=False, initial="id", choices=[]
         )
-        value_key = forms.ChoiceField(
-            label="Value key", required=False, initial="id", choices=[]
+        agg_value_key = forms.ChoiceField(
+            label="Agg Value key", required=False, initial="id", choices=[]
         )
         agg_func = forms.ChoiceField(
             label="Aggregation function",
@@ -514,7 +516,7 @@ class BarPlot(PlotKind):
             ),
             Div(Field("group_key"), css_class="col-12"),
             Div(Field("agg_func"), css_class="col-6"),
-            Div(Field("value_key"), css_class="col-6"),
+            Div(Field("agg_value_key"), css_class="col-6"),
             Div(
                 Div(
                     Div(Field("order_by"), css_class="col-6"),
-- 
GitLab