diff --git a/scipost_django/graphs/graphs/plotkind.py b/scipost_django/graphs/graphs/plotkind.py index 50fe5b64b4164b38cf9827978197c3c0e600c329..e7f538c1045f49133399f156c39abd877684847d 100644 --- a/scipost_django/graphs/graphs/plotkind.py +++ b/scipost_django/graphs/graphs/plotkind.py @@ -2,7 +2,7 @@ __copyright__ = "Copyright © Stichting SciPost (SciPost Foundation)" __license__ = "AGPL v3" from django import forms -from django.db.models import Q, Count +from django.db.models import Q, Avg, Count, Sum from matplotlib.figure import Figure import pandas as pd @@ -20,7 +20,7 @@ class PlotKind: Generic class for a plot kind. """ - name: str + name: str = "Default" class Options(BaseOptions): prefix = "plot_kind_" @@ -190,8 +190,15 @@ class MapPlot(PlotKind): return fig def draw_colorbar(self, fig: Figure, **kwargs): + color_bar_title = ( + self.options.get("agg_func", "count").capitalize() + + " of " + + self.options.get("agg_key", "id") + + " per country" + ) + cax = fig.add_axes([0.385, 0.2, 0.45, 0.02]) - cax.set_title("Counts", fontsize="small") + cax.set_title(color_bar_title, fontsize="small") cax.tick_params(axis="x", length=2, direction="out", which="major") cax.tick_params(axis="x", length=1.5, direction="out", which="minor") cax.grid(False) @@ -215,12 +222,14 @@ class MapPlot(PlotKind): self.draw_colorbar(fig) ax, cax, _ = fig.get_axes() - countries, count = self.get_data() - df_counts = pd.DataFrame({"ISO_A2_EH": countries, "count": count}) - vmax = df_counts["count"].max() - BASE_WORLD.merge(df_counts, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot( + ax.set_title(f"World map distribution of {self.plotter.model.__name__}") + + countries, agg = self.get_data() + df_agg = pd.DataFrame({"ISO_A2_EH": countries, "agg": agg}) + vmax = df_agg["agg"].max() + BASE_WORLD.merge(df_agg, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot( ax=ax, - column="count", + column="agg", legend=True, legend_kwds={"orientation": "horizontal"}, cmap=LinearSegmentedColormap.from_list("custom", OKLCH), @@ -250,16 +259,75 @@ class MapPlot(PlotKind): Return the a tuple of lists of countries and their counts. """ qs = self.plotter.get_queryset() + + value_key = self.options.get("agg_key", "id") or "id" + + group_by_country_agg = qs.filter( + Q(**{self.plotter.country_key + "__isnull": False}) + ).values(self.plotter.country_key) + + match self.options.get("agg_func", "count"): + case "count": + agg_func = Count + case "sum": + agg_func = Sum + case "avg": + agg_func = Avg + case _: + raise ValueError("Invalid aggregation function") + + group_by_country_agg = group_by_country_agg.annotate(agg=agg_func(value_key)) + + countries, agg = zip( + *group_by_country_agg.values_list(self.plotter.country_key, "agg") + ) + + # Convert the aggregated data to floats + agg = [float(a) for a in agg] + + return countries, agg + + class Options(BaseOptions): prefix = "map_plot_" - count_key = self.options.get("count_key", "id") - group_by_country_count = ( - qs.filter(Q(**{self.plotter.country_key + "__isnull": False})) - .values(self.plotter.country_key) - .annotate(count=Count(count_key)) - prefix = "map_plot_" + agg_func = forms.ChoiceField( + label="Aggregation function", + choices=[ + ("count", "Count"), + ("sum", "Sum"), + ("avg", "Average"), + ], + required=False, + initial="count", ) + agg_key = forms.CharField(label="Aggregation key", initial="id", required=False) - countries, count = zip( - *group_by_country_count.values_list(self.plotter.country_key, "count") + @classmethod + def get_plot_options_form_layout_row_content(cls): + layout = Layout( + Div(Field("agg_func"), css_class="col-6"), + Div(Field("agg_key"), css_class="col-6"), ) - return countries, count + + # Prefix every field in the layout with the prefix + def prefix_field(field): + """ + Recursively prefix the fields in a layout. + Return type is irrelevant, as it modifies the argument directly. + """ + contained_fields = getattr(field, "fields", None) + if contained_fields is None: + return + + # If the crispy field is a Field type with a single string identifier, prefix it + if ( + isinstance(field, Field) + and len(contained_fields) == 1 + and isinstance(field_key := contained_fields[0], str) + ): + field.fields = [cls.Options.prefix + field_key] + else: + return [prefix_field(f) for f in contained_fields] + + prefix_field(layout) + + return layout diff --git a/scipost_django/graphs/graphs/plotter.py b/scipost_django/graphs/graphs/plotter.py index aeed2979a74903fdc7e98ef2902de1594692f432..67b368a5d0f2fcfda7eb6c865c5e8e8c032de36c 100644 --- a/scipost_django/graphs/graphs/plotter.py +++ b/scipost_django/graphs/graphs/plotter.py @@ -77,7 +77,7 @@ class ModelFieldPlotter(ABC): plt.style.use( [ ALL_MPL_THEMES.get("_base", ""), - ALL_MPL_THEMES.get(options.get("theme", None), "default"), + ALL_MPL_THEMES.get(options.get("theme", None), "light"), ] )