SciPost Code Repository

Skip to content
Snippets Groups Projects
Commit 21a64456 authored by George Katsikas's avatar George Katsikas :goat:
Browse files

make plotter an instance parameter of plot kind

parent f38c25f7
No related branches found
No related tags found
No related merge requests found
...@@ -50,17 +50,11 @@ class PlotKindSelectForm(forms.Form): ...@@ -50,17 +50,11 @@ class PlotKindSelectForm(forms.Form):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# If a plot kind is already selected, populate the form with its options # If a plot kind is already selected, populate the form with its options
if plot_kind := PlotKind.from_name(self.data.get("plot_kind", None), self.data): plot_kind_class_name = self.data.get("plot_kind", None)
self.fields.update(plot_kind.Options.get_option_fields()) if plot_kind_class := PlotKind.class_from_name(plot_kind_class_name):
self.fields.update(plot_kind_class.Options.get_option_fields())
self.kind = plot_kind self.kind_class = plot_kind_class
def clean(self):
cleaned_data = super().clean()
# Recreate plot kind with cleaned data
self.kind = PlotKind.from_name(self.data.get("plot_kind", None), cleaned_data)
return cleaned_data
class GenericPlotOptionsForm(forms.Form): class GenericPlotOptionsForm(forms.Form):
......
...@@ -25,15 +25,18 @@ class PlotKind: ...@@ -25,15 +25,18 @@ class PlotKind:
prefix = "plot_kind_" prefix = "plot_kind_"
pass pass
def __init__(self, options: "OptionDict" = {}): def __init__(self, plotter: "ModelFieldPlotter", options: "OptionDict" = {}):
self.plotter = plotter
self.options = self.Options.parse_prefixed_options(options) self.options = self.Options.parse_prefixed_options(options)
@classmethod @classmethod
def from_name(cls, name: str, *args, **kwargs): def class_from_name(cls, name: str):
from graphs.graphs import ALL_PLOT_KINDS from graphs.graphs import ALL_PLOT_KINDS
if cls_name := ALL_PLOT_KINDS.get(name, None): if cls_name := ALL_PLOT_KINDS.get(name, None):
return cls_name(*args, **kwargs) return cls_name
return PlotKind
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -49,24 +52,24 @@ class PlotKind: ...@@ -49,24 +52,24 @@ class PlotKind:
""" """
return Figure(**kwargs) return Figure(**kwargs)
def get_data(self, plotter: "ModelFieldPlotter") -> tuple[list[int], list[Any]]: def get_data(self) -> tuple[list[int], list[Any]]:
""" """
Obtain the values to plot from the queryset. Obtain the values to plot from the queryset.
""" """
qs = plotter.get_queryset() qs = self.plotter.get_queryset()
y = qs.values_list("id", flat=True) y = qs.values_list("id", flat=True)
x = list(range(len(y))) x = list(range(len(y)))
return x, y return x, y
def plot(self, plotter: "ModelFieldPlotter"): def plot(self):
""" """
Plot the data on a the figure. Plot the data on a the figure.
""" """
fig = self.get_figure() fig = self.get_figure()
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
ax.set_title(f"{self.get_name()} plot of {plotter.model.__name__}") ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}")
x, y = self.get_data(plotter) x, y = self.get_data()
ax.plot(x, y) ax.plot(x, y)
return fig return fig
...@@ -74,36 +77,36 @@ class PlotKind: ...@@ -74,36 +77,36 @@ class PlotKind:
class TimelinePlot(PlotKind): class TimelinePlot(PlotKind):
name = "timeline" name = "timeline"
def plot(self, plotter: "ModelFieldPlotter"): def plot(self):
fig = super().plot(plotter) fig = super().plot()
ax = fig.get_axes()[0] ax = fig.get_axes()[0]
ax.set_xlabel(plotter.date_key) ax.set_xlabel(self.plotter.date_key)
ax.set_ylabel(self.options.get("y_key", "id")) ax.set_ylabel(self.options.get("y_key", "id"))
return fig return fig
def get_data(self, plotter: "ModelFieldPlotter", **kwargs): def get_data(self):
y_key = self.options.get("y_key", "id") or "id" y_key = self.options.get("y_key", "id") or "id"
# Filter the queryset to only include entries with a date and a y value # Filter the queryset to only include entries with a date and a y value
query_filters = Q( query_filters = Q(
**{ **{
plotter.date_key + "__isnull": False, self.plotter.date_key + "__isnull": False,
y_key + "__isnull": False, y_key + "__isnull": False,
} }
) )
# Filter the queryset according to the date limits if they are set # Filter the queryset according to the date limits if they are set
if x_lim_min := self.options.get("x_lim_min", None): if x_lim_min := self.options.get("x_lim_min", None):
query_filters &= Q(**{plotter.date_key + "__gte": x_lim_min}) query_filters &= Q(**{self.plotter.date_key + "__gte": x_lim_min})
if x_lim_max := self.options.get("x_lim_max", None): if x_lim_max := self.options.get("x_lim_max", None):
query_filters &= Q(**{plotter.date_key + "__lte": x_lim_max}) query_filters &= Q(**{self.plotter.date_key + "__lte": x_lim_max})
qs = plotter.get_queryset() qs = self.plotter.get_queryset()
qs = qs.filter(query_filters) qs = qs.filter(query_filters)
qs = qs.order_by(plotter.date_key) qs = qs.order_by(self.plotter.date_key)
x, y = zip(*qs.values_list(plotter.date_key, y_key)) x, y = zip(*qs.values_list(self.plotter.date_key, y_key))
return x, y return x, y
...@@ -158,7 +161,7 @@ class MapPlot(PlotKind): ...@@ -158,7 +161,7 @@ class MapPlot(PlotKind):
cax0.tick_params(axis="x", length=1.5, direction="out", which="minor") cax0.tick_params(axis="x", length=1.5, direction="out", which="minor")
cax0.grid(False) cax0.grid(False)
def plot(self, plotter: "ModelFieldPlotter"): def plot(self):
from graphs.graphs import BASE_WORLD, OKLCH from graphs.graphs import BASE_WORLD, OKLCH
from matplotlib.colors import LinearSegmentedColormap, LogNorm from matplotlib.colors import LinearSegmentedColormap, LogNorm
...@@ -166,7 +169,7 @@ class MapPlot(PlotKind): ...@@ -166,7 +169,7 @@ class MapPlot(PlotKind):
self.draw_colorbar(fig) self.draw_colorbar(fig)
ax, cax, _ = fig.get_axes() ax, cax, _ = fig.get_axes()
countries, count = self.get_data(plotter) countries, count = self.get_data()
df_counts = pd.DataFrame({"ISO_A2_EH": countries, "count": count}) df_counts = pd.DataFrame({"ISO_A2_EH": countries, "count": count})
vmax = df_counts["count"].max() vmax = df_counts["count"].max()
BASE_WORLD.merge(df_counts, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot( BASE_WORLD.merge(df_counts, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot(
...@@ -188,16 +191,19 @@ class MapPlot(PlotKind): ...@@ -188,16 +191,19 @@ class MapPlot(PlotKind):
return fig return fig
def get_data(self, plotter: "ModelFieldPlotter", **kwargs): def get_data(self):
qs = plotter.get_queryset() """
Return the a tuple of lists of countries and their counts.
"""
qs = self.plotter.get_queryset()
count_key = self.options.get("count_key", "id") count_key = self.options.get("count_key", "id")
group_by_country_count = ( group_by_country_count = (
qs.filter(Q(**{plotter.country_key + "__isnull": False})) qs.filter(Q(**{self.plotter.country_key + "__isnull": False}))
.values(plotter.country_key) .values(self.plotter.country_key)
.annotate(count=Count(count_key)) .annotate(count=Count(count_key))
) )
countries, count = zip( countries, count = zip(
*group_by_country_count.values_list(plotter.country_key, "count") *group_by_country_count.values_list(self.plotter.country_key, "count")
) )
return countries, count return countries, count
...@@ -79,7 +79,7 @@ class ModelFieldPlotter(ABC): ...@@ -79,7 +79,7 @@ class ModelFieldPlotter(ABC):
] ]
) )
fig = kind.plot(plotter=self) fig = kind.plot()
fig.suptitle(options.get("title", None)) fig.suptitle(options.get("title", None))
return fig return fig
......
...@@ -50,7 +50,10 @@ class PlotView(View): ...@@ -50,7 +50,10 @@ class PlotView(View):
cleaned_data = form.clean() cleaned_data = form.clean()
self.plotter = form.model_field_select_form.plotter self.plotter = form.model_field_select_form.plotter
self.kind = form.plot_kind_select_form.kind self.kind = form.plot_kind_select_form.kind_class(
options=form.plot_kind_select_form.cleaned_data,
plotter=self.plotter,
)
self.plot_options = { self.plot_options = {
"plot_kind": {}, "plot_kind": {},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment