SciPost Code Repository

Skip to content
Snippets Groups Projects
Commit 21a64456 authored by George Katsikas's avatar George Katsikas :goat:
Browse files

make plotter an instance parameter of plot kind

parent f38c25f7
No related branches found
No related tags found
No related merge requests found
......@@ -50,17 +50,11 @@ class PlotKindSelectForm(forms.Form):
super().__init__(*args, **kwargs)
# If a plot kind is already selected, populate the form with its options
if plot_kind := PlotKind.from_name(self.data.get("plot_kind", None), self.data):
self.fields.update(plot_kind.Options.get_option_fields())
plot_kind_class_name = self.data.get("plot_kind", None)
if plot_kind_class := PlotKind.class_from_name(plot_kind_class_name):
self.fields.update(plot_kind_class.Options.get_option_fields())
self.kind = plot_kind
def clean(self):
cleaned_data = super().clean()
# Recreate plot kind with cleaned data
self.kind = PlotKind.from_name(self.data.get("plot_kind", None), cleaned_data)
return cleaned_data
self.kind_class = plot_kind_class
class GenericPlotOptionsForm(forms.Form):
......
......@@ -25,15 +25,18 @@ class PlotKind:
prefix = "plot_kind_"
pass
def __init__(self, options: "OptionDict" = {}):
def __init__(self, plotter: "ModelFieldPlotter", options: "OptionDict" = {}):
self.plotter = plotter
self.options = self.Options.parse_prefixed_options(options)
@classmethod
def from_name(cls, name: str, *args, **kwargs):
def class_from_name(cls, name: str):
from graphs.graphs import ALL_PLOT_KINDS
if cls_name := ALL_PLOT_KINDS.get(name, None):
return cls_name(*args, **kwargs)
return cls_name
return PlotKind
@classmethod
def get_name(cls) -> str:
......@@ -49,24 +52,24 @@ class PlotKind:
"""
return Figure(**kwargs)
def get_data(self, plotter: "ModelFieldPlotter") -> tuple[list[int], list[Any]]:
def get_data(self) -> tuple[list[int], list[Any]]:
"""
Obtain the values to plot from the queryset.
"""
qs = plotter.get_queryset()
qs = self.plotter.get_queryset()
y = qs.values_list("id", flat=True)
x = list(range(len(y)))
return x, y
def plot(self, plotter: "ModelFieldPlotter"):
def plot(self):
"""
Plot the data on a the figure.
"""
fig = self.get_figure()
ax = fig.add_subplot(111)
ax.set_title(f"{self.get_name()} plot of {plotter.model.__name__}")
ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}")
x, y = self.get_data(plotter)
x, y = self.get_data()
ax.plot(x, y)
return fig
......@@ -74,36 +77,36 @@ class PlotKind:
class TimelinePlot(PlotKind):
name = "timeline"
def plot(self, plotter: "ModelFieldPlotter"):
fig = super().plot(plotter)
def plot(self):
fig = super().plot()
ax = fig.get_axes()[0]
ax.set_xlabel(plotter.date_key)
ax.set_xlabel(self.plotter.date_key)
ax.set_ylabel(self.options.get("y_key", "id"))
return fig
def get_data(self, plotter: "ModelFieldPlotter", **kwargs):
def get_data(self):
y_key = self.options.get("y_key", "id") or "id"
# Filter the queryset to only include entries with a date and a y value
query_filters = Q(
**{
plotter.date_key + "__isnull": False,
self.plotter.date_key + "__isnull": False,
y_key + "__isnull": False,
}
)
# Filter the queryset according to the date limits if they are set
if x_lim_min := self.options.get("x_lim_min", None):
query_filters &= Q(**{plotter.date_key + "__gte": x_lim_min})
query_filters &= Q(**{self.plotter.date_key + "__gte": x_lim_min})
if x_lim_max := self.options.get("x_lim_max", None):
query_filters &= Q(**{plotter.date_key + "__lte": x_lim_max})
query_filters &= Q(**{self.plotter.date_key + "__lte": x_lim_max})
qs = plotter.get_queryset()
qs = self.plotter.get_queryset()
qs = qs.filter(query_filters)
qs = qs.order_by(plotter.date_key)
qs = qs.order_by(self.plotter.date_key)
x, y = zip(*qs.values_list(plotter.date_key, y_key))
x, y = zip(*qs.values_list(self.plotter.date_key, y_key))
return x, y
......@@ -158,7 +161,7 @@ class MapPlot(PlotKind):
cax0.tick_params(axis="x", length=1.5, direction="out", which="minor")
cax0.grid(False)
def plot(self, plotter: "ModelFieldPlotter"):
def plot(self):
from graphs.graphs import BASE_WORLD, OKLCH
from matplotlib.colors import LinearSegmentedColormap, LogNorm
......@@ -166,7 +169,7 @@ class MapPlot(PlotKind):
self.draw_colorbar(fig)
ax, cax, _ = fig.get_axes()
countries, count = self.get_data(plotter)
countries, count = self.get_data()
df_counts = pd.DataFrame({"ISO_A2_EH": countries, "count": count})
vmax = df_counts["count"].max()
BASE_WORLD.merge(df_counts, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot(
......@@ -188,16 +191,19 @@ class MapPlot(PlotKind):
return fig
def get_data(self, plotter: "ModelFieldPlotter", **kwargs):
qs = plotter.get_queryset()
def get_data(self):
"""
Return the a tuple of lists of countries and their counts.
"""
qs = self.plotter.get_queryset()
count_key = self.options.get("count_key", "id")
group_by_country_count = (
qs.filter(Q(**{plotter.country_key + "__isnull": False}))
.values(plotter.country_key)
qs.filter(Q(**{self.plotter.country_key + "__isnull": False}))
.values(self.plotter.country_key)
.annotate(count=Count(count_key))
)
countries, count = zip(
*group_by_country_count.values_list(plotter.country_key, "count")
*group_by_country_count.values_list(self.plotter.country_key, "count")
)
return countries, count
......@@ -79,7 +79,7 @@ class ModelFieldPlotter(ABC):
]
)
fig = kind.plot(plotter=self)
fig = kind.plot()
fig.suptitle(options.get("title", None))
return fig
......
......@@ -50,7 +50,10 @@ class PlotView(View):
cleaned_data = form.clean()
self.plotter = form.model_field_select_form.plotter
self.kind = form.plot_kind_select_form.kind
self.kind = form.plot_kind_select_form.kind_class(
options=form.plot_kind_select_form.cleaned_data,
plotter=self.plotter,
)
self.plot_options = {
"plot_kind": {},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment