SciPost Code Repository

Skip to content
Snippets Groups Projects
Commit 1dcf6ca8 authored by George Katsikas's avatar George Katsikas :goat:
Browse files

refactor(graphs): :recycle:️ differentiate aggragated value key

Renames value_key to agg_value_key for the plot kinds where it is aggregated
parent 2d88f898
No related branches found
No related tags found
No related merge requests found
...@@ -137,7 +137,8 @@ class PlotOptionsForm(InitialCoalescedForm): ...@@ -137,7 +137,8 @@ class PlotOptionsForm(InitialCoalescedForm):
""" """
FIELD_ADMISSIBLE_TYPES: dict[str, list[str]] = { FIELD_ADMISSIBLE_TYPES: dict[str, list[str]] = {
"value_key": ["int", "float"], "value_key": ["int", "float", "date"],
"agg_value_key": ["int", "float"],
"timeline_key": ["date", "datetime"], "timeline_key": ["date", "datetime"],
"group_key": ["str", "int", "country"], "group_key": ["str", "int", "country"],
"country_key": ["country"], "country_key": ["country"],
......
...@@ -231,8 +231,8 @@ class MapPlot(PlotKind): ...@@ -231,8 +231,8 @@ class MapPlot(PlotKind):
from graphs.graphs import BASE_WORLD, OKLCH from graphs.graphs import BASE_WORLD, OKLCH
agg_func = self.options.get("agg_func", "count") agg_func = self.options.get("agg_func", "count")
value_key_display = self.plotter.get_model_field_display( agg_value_key_display = self.plotter.get_model_field_display(
self.options.get("value_key") self.options.get("agg_value_key")
) )
country_key_display = ( country_key_display = (
self.plotter.get_model_field_display(self.options.get("country_key")) self.plotter.get_model_field_display(self.options.get("country_key"))
...@@ -241,12 +241,12 @@ class MapPlot(PlotKind): ...@@ -241,12 +241,12 @@ class MapPlot(PlotKind):
if agg_func == "count": if agg_func == "count":
plot_title = "{model} per {country}" plot_title = "{model} per {country}"
else: else:
plot_title = "{agg_func} of {model}' {value_key_display} per {country}" plot_title = "{agg_func} of {model}' {agg_value_key_display} per {country}"
plot_title = plot_title.format( plot_title = plot_title.format(
model=self.plotter.model._meta.verbose_name_plural, model=self.plotter.model._meta.verbose_name_plural,
agg_func=agg_func, agg_func=agg_func,
value_key_display=value_key_display, agg_value_key_display=agg_value_key_display,
country=country_key_display, country=country_key_display,
).capitalize() ).capitalize()
color_plot_title, _ = plot_title.split(" per ") color_plot_title, _ = plot_title.split(" per ")
...@@ -300,7 +300,7 @@ class MapPlot(PlotKind): ...@@ -300,7 +300,7 @@ class MapPlot(PlotKind):
""" """
qs = self.plotter.get_queryset() qs = self.plotter.get_queryset()
value_key = self.options.get("value_key", "id") or "id" agg_value_key = self.options.get("agg_value_key", "id") or "id"
country_key = self.options.get("country_key") country_key = self.options.get("country_key")
if country_key is None: if country_key is None:
...@@ -310,9 +310,9 @@ class MapPlot(PlotKind): ...@@ -310,9 +310,9 @@ class MapPlot(PlotKind):
case "count": case "count":
agg_func = Count("id") agg_func = Count("id")
case "sum": case "sum":
agg_func = Sum(value_key) agg_func = Sum(agg_value_key)
case "avg": case "avg":
agg_func = Avg(value_key) agg_func = Avg(agg_value_key)
case _: case _:
raise ValueError("Invalid aggregation function") raise ValueError("Invalid aggregation function")
...@@ -341,7 +341,9 @@ class MapPlot(PlotKind): ...@@ -341,7 +341,9 @@ class MapPlot(PlotKind):
required=False, required=False,
initial="count", initial="count",
) )
value_key = forms.ChoiceField(label="Value key", required=False, choices=[]) agg_value_key = forms.ChoiceField(
label="Agg Value key", required=False, choices=[]
)
country_key = forms.ChoiceField(label="Country key", required=False, choices=[]) country_key = forms.ChoiceField(label="Country key", required=False, choices=[])
@classmethod @classmethod
...@@ -349,7 +351,7 @@ class MapPlot(PlotKind): ...@@ -349,7 +351,7 @@ class MapPlot(PlotKind):
return Layout( return Layout(
Div(Field("country_key"), css_class="col-12"), Div(Field("country_key"), css_class="col-12"),
Div(Field("agg_func"), css_class="col-6"), Div(Field("agg_func"), css_class="col-6"),
Div(Field("value_key"), css_class="col-6"), Div(Field("agg_value_key"), css_class="col-6"),
) )
...@@ -374,17 +376,17 @@ class BarPlot(PlotKind): ...@@ -374,17 +376,17 @@ class BarPlot(PlotKind):
): ):
ax.set(**{f"{group_label_axis}label": group_key_label.capitalize()}) ax.set(**{f"{group_label_axis}label": group_key_label.capitalize()})
if value_key_name := self.plotter.get_model_field_display( if agg_value_key_name := self.plotter.get_model_field_display(
self.options.get("value_key") self.options.get("agg_value_key")
): ):
value_key_label = f"{agg_func} of {value_key_name}" agg_value_key_label = f"{agg_func} of {agg_value_key_name}"
if agg_func == "count": if agg_func == "count":
# Simplify label and set locator to integer # Simplify label and set locator to integer
value_key_label = "Count" agg_value_key_label = "Count"
axis = getattr(ax, f"{value_label_axis}axis") axis = getattr(ax, f"{value_label_axis}axis")
axis.get_major_locator().set_params(integer=True) axis.get_major_locator().set_params(integer=True)
ax.set(**{f"{value_label_axis}label": value_key_label.capitalize()}) ax.set(**{f"{value_label_axis}label": agg_value_key_label.capitalize()})
try: try:
groups, vals = self.get_data() groups, vals = self.get_data()
...@@ -409,7 +411,7 @@ class BarPlot(PlotKind): ...@@ -409,7 +411,7 @@ class BarPlot(PlotKind):
return fig return fig
def get_data(self): def get_data(self):
value_key = self.options.get("value_key", "id") or "id" agg_value_key = self.options.get("agg_value_key", "id") or "id"
group_key = self.options.get("group_key") group_key = self.options.get("group_key")
direction = self.options.get("direction", "vertical") or "vertical" direction = self.options.get("direction", "vertical") or "vertical"
...@@ -420,11 +422,11 @@ class BarPlot(PlotKind): ...@@ -420,11 +422,11 @@ class BarPlot(PlotKind):
match self.options.get("agg_func", "count"): match self.options.get("agg_func", "count"):
case "count": case "count":
agg_func = Count("id") agg_func = Count(group_key)
case "sum": case "sum":
agg_func = Sum(value_key) agg_func = Sum(agg_value_key)
case "avg": case "avg":
agg_func = Avg(value_key) agg_func = Avg(agg_value_key)
case _: case _:
raise ValueError("Invalid aggregation function") raise ValueError("Invalid aggregation function")
...@@ -475,8 +477,8 @@ class BarPlot(PlotKind): ...@@ -475,8 +477,8 @@ class BarPlot(PlotKind):
group_key = forms.ChoiceField( group_key = forms.ChoiceField(
label="Group by key", required=False, initial="id", choices=[] label="Group by key", required=False, initial="id", choices=[]
) )
value_key = forms.ChoiceField( agg_value_key = forms.ChoiceField(
label="Value key", required=False, initial="id", choices=[] label="Agg Value key", required=False, initial="id", choices=[]
) )
agg_func = forms.ChoiceField( agg_func = forms.ChoiceField(
label="Aggregation function", label="Aggregation function",
...@@ -514,7 +516,7 @@ class BarPlot(PlotKind): ...@@ -514,7 +516,7 @@ class BarPlot(PlotKind):
), ),
Div(Field("group_key"), css_class="col-12"), Div(Field("group_key"), css_class="col-12"),
Div(Field("agg_func"), css_class="col-6"), Div(Field("agg_func"), css_class="col-6"),
Div(Field("value_key"), css_class="col-6"), Div(Field("agg_value_key"), css_class="col-6"),
Div( Div(
Div( Div(
Div(Field("order_by"), css_class="col-6"), Div(Field("order_by"), css_class="col-6"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment