SciPost Code Repository

Skip to content
Snippets Groups Projects
Commit 1dcf6ca8 authored by George Katsikas's avatar George Katsikas :goat:
Browse files

refactor(graphs): :recycle:️ differentiate aggragated value key

Renames value_key to agg_value_key for the plot kinds where it is aggregated
parent 2d88f898
No related branches found
No related tags found
No related merge requests found
......@@ -137,7 +137,8 @@ class PlotOptionsForm(InitialCoalescedForm):
"""
FIELD_ADMISSIBLE_TYPES: dict[str, list[str]] = {
"value_key": ["int", "float"],
"value_key": ["int", "float", "date"],
"agg_value_key": ["int", "float"],
"timeline_key": ["date", "datetime"],
"group_key": ["str", "int", "country"],
"country_key": ["country"],
......
......@@ -231,8 +231,8 @@ class MapPlot(PlotKind):
from graphs.graphs import BASE_WORLD, OKLCH
agg_func = self.options.get("agg_func", "count")
value_key_display = self.plotter.get_model_field_display(
self.options.get("value_key")
agg_value_key_display = self.plotter.get_model_field_display(
self.options.get("agg_value_key")
)
country_key_display = (
self.plotter.get_model_field_display(self.options.get("country_key"))
......@@ -241,12 +241,12 @@ class MapPlot(PlotKind):
if agg_func == "count":
plot_title = "{model} per {country}"
else:
plot_title = "{agg_func} of {model}' {value_key_display} per {country}"
plot_title = "{agg_func} of {model}' {agg_value_key_display} per {country}"
plot_title = plot_title.format(
model=self.plotter.model._meta.verbose_name_plural,
agg_func=agg_func,
value_key_display=value_key_display,
agg_value_key_display=agg_value_key_display,
country=country_key_display,
).capitalize()
color_plot_title, _ = plot_title.split(" per ")
......@@ -300,7 +300,7 @@ class MapPlot(PlotKind):
"""
qs = self.plotter.get_queryset()
value_key = self.options.get("value_key", "id") or "id"
agg_value_key = self.options.get("agg_value_key", "id") or "id"
country_key = self.options.get("country_key")
if country_key is None:
......@@ -310,9 +310,9 @@ class MapPlot(PlotKind):
case "count":
agg_func = Count("id")
case "sum":
agg_func = Sum(value_key)
agg_func = Sum(agg_value_key)
case "avg":
agg_func = Avg(value_key)
agg_func = Avg(agg_value_key)
case _:
raise ValueError("Invalid aggregation function")
......@@ -341,7 +341,9 @@ class MapPlot(PlotKind):
required=False,
initial="count",
)
value_key = forms.ChoiceField(label="Value key", required=False, choices=[])
agg_value_key = forms.ChoiceField(
label="Agg Value key", required=False, choices=[]
)
country_key = forms.ChoiceField(label="Country key", required=False, choices=[])
@classmethod
......@@ -349,7 +351,7 @@ class MapPlot(PlotKind):
return Layout(
Div(Field("country_key"), css_class="col-12"),
Div(Field("agg_func"), css_class="col-6"),
Div(Field("value_key"), css_class="col-6"),
Div(Field("agg_value_key"), css_class="col-6"),
)
......@@ -374,17 +376,17 @@ class BarPlot(PlotKind):
):
ax.set(**{f"{group_label_axis}label": group_key_label.capitalize()})
if value_key_name := self.plotter.get_model_field_display(
self.options.get("value_key")
if agg_value_key_name := self.plotter.get_model_field_display(
self.options.get("agg_value_key")
):
value_key_label = f"{agg_func} of {value_key_name}"
agg_value_key_label = f"{agg_func} of {agg_value_key_name}"
if agg_func == "count":
# Simplify label and set locator to integer
value_key_label = "Count"
agg_value_key_label = "Count"
axis = getattr(ax, f"{value_label_axis}axis")
axis.get_major_locator().set_params(integer=True)
ax.set(**{f"{value_label_axis}label": value_key_label.capitalize()})
ax.set(**{f"{value_label_axis}label": agg_value_key_label.capitalize()})
try:
groups, vals = self.get_data()
......@@ -409,7 +411,7 @@ class BarPlot(PlotKind):
return fig
def get_data(self):
value_key = self.options.get("value_key", "id") or "id"
agg_value_key = self.options.get("agg_value_key", "id") or "id"
group_key = self.options.get("group_key")
direction = self.options.get("direction", "vertical") or "vertical"
......@@ -420,11 +422,11 @@ class BarPlot(PlotKind):
match self.options.get("agg_func", "count"):
case "count":
agg_func = Count("id")
agg_func = Count(group_key)
case "sum":
agg_func = Sum(value_key)
agg_func = Sum(agg_value_key)
case "avg":
agg_func = Avg(value_key)
agg_func = Avg(agg_value_key)
case _:
raise ValueError("Invalid aggregation function")
......@@ -475,8 +477,8 @@ class BarPlot(PlotKind):
group_key = forms.ChoiceField(
label="Group by key", required=False, initial="id", choices=[]
)
value_key = forms.ChoiceField(
label="Value key", required=False, initial="id", choices=[]
agg_value_key = forms.ChoiceField(
label="Agg Value key", required=False, initial="id", choices=[]
)
agg_func = forms.ChoiceField(
label="Aggregation function",
......@@ -514,7 +516,7 @@ class BarPlot(PlotKind):
),
Div(Field("group_key"), css_class="col-12"),
Div(Field("agg_func"), css_class="col-6"),
Div(Field("value_key"), css_class="col-6"),
Div(Field("agg_value_key"), css_class="col-6"),
Div(
Div(
Div(Field("order_by"), css_class="col-6"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment