diff --git a/scipost_django/graphs/forms.py b/scipost_django/graphs/forms.py
index a185911f2f96ab588da8419c9a79f24b226eb20a..4b6950b1e85c14e036d23122b6cab0ec671aed77 100644
--- a/scipost_django/graphs/forms.py
+++ b/scipost_django/graphs/forms.py
@@ -50,17 +50,11 @@ class PlotKindSelectForm(forms.Form):
         super().__init__(*args, **kwargs)
 
         # If a plot kind is already selected, populate the form with its options
-        if plot_kind := PlotKind.from_name(self.data.get("plot_kind", None), self.data):
-            self.fields.update(plot_kind.Options.get_option_fields())
+        plot_kind_class_name = self.data.get("plot_kind", None)
+        if plot_kind_class := PlotKind.class_from_name(plot_kind_class_name):
+            self.fields.update(plot_kind_class.Options.get_option_fields())
 
-        self.kind = plot_kind
-
-    def clean(self):
-        cleaned_data = super().clean()
-
-        # Recreate plot kind with cleaned data
-        self.kind = PlotKind.from_name(self.data.get("plot_kind", None), cleaned_data)
-        return cleaned_data
+        self.kind_class = plot_kind_class
 
 
 class GenericPlotOptionsForm(forms.Form):
diff --git a/scipost_django/graphs/graphs/plotkind.py b/scipost_django/graphs/graphs/plotkind.py
index 02b3817ef0247fe73d72ae7390280a82fda0f515..babf740f2e376046368d5e8059d2c48c8acb42f8 100644
--- a/scipost_django/graphs/graphs/plotkind.py
+++ b/scipost_django/graphs/graphs/plotkind.py
@@ -25,15 +25,18 @@ class PlotKind:
         prefix = "plot_kind_"
         pass
 
-    def __init__(self, options: "OptionDict" = {}):
+    def __init__(self, plotter: "ModelFieldPlotter", options: "OptionDict" = {}):
+        self.plotter = plotter
         self.options = self.Options.parse_prefixed_options(options)
 
     @classmethod
-    def from_name(cls, name: str, *args, **kwargs):
+    def class_from_name(cls, name: str):
         from graphs.graphs import ALL_PLOT_KINDS
 
         if cls_name := ALL_PLOT_KINDS.get(name, None):
-            return cls_name(*args, **kwargs)
+            return cls_name
+
+        return PlotKind
 
     @classmethod
     def get_name(cls) -> str:
@@ -49,24 +52,24 @@ class PlotKind:
         """
         return Figure(**kwargs)
 
-    def get_data(self, plotter: "ModelFieldPlotter") -> tuple[list[int], list[Any]]:
+    def get_data(self) -> tuple[list[int], list[Any]]:
         """
         Obtain the values to plot from the queryset.
         """
-        qs = plotter.get_queryset()
+        qs = self.plotter.get_queryset()
         y = qs.values_list("id", flat=True)
         x = list(range(len(y)))
         return x, y
 
-    def plot(self, plotter: "ModelFieldPlotter"):
+    def plot(self):
         """
         Plot the data on a the figure.
         """
         fig = self.get_figure()
         ax = fig.add_subplot(111)
-        ax.set_title(f"{self.get_name()} plot of {plotter.model.__name__}")
+        ax.set_title(f"{self.get_name()} plot of {self.plotter.model.__name__}")
 
-        x, y = self.get_data(plotter)
+        x, y = self.get_data()
         ax.plot(x, y)
         return fig
 
@@ -74,36 +77,36 @@ class PlotKind:
 class TimelinePlot(PlotKind):
     name = "timeline"
 
-    def plot(self, plotter: "ModelFieldPlotter"):
-        fig = super().plot(plotter)
+    def plot(self):
+        fig = super().plot()
         ax = fig.get_axes()[0]
 
-        ax.set_xlabel(plotter.date_key)
+        ax.set_xlabel(self.plotter.date_key)
         ax.set_ylabel(self.options.get("y_key", "id"))
 
         return fig
 
-    def get_data(self, plotter: "ModelFieldPlotter", **kwargs):
+    def get_data(self):
         y_key = self.options.get("y_key", "id") or "id"
 
         # Filter the queryset to only include entries with a date and a y value
         query_filters = Q(
             **{
-                plotter.date_key + "__isnull": False,
+                self.plotter.date_key + "__isnull": False,
                 y_key + "__isnull": False,
             }
         )
         # Filter the queryset according to the date limits if they are set
         if x_lim_min := self.options.get("x_lim_min", None):
-            query_filters &= Q(**{plotter.date_key + "__gte": x_lim_min})
+            query_filters &= Q(**{self.plotter.date_key + "__gte": x_lim_min})
         if x_lim_max := self.options.get("x_lim_max", None):
-            query_filters &= Q(**{plotter.date_key + "__lte": x_lim_max})
+            query_filters &= Q(**{self.plotter.date_key + "__lte": x_lim_max})
 
-        qs = plotter.get_queryset()
+        qs = self.plotter.get_queryset()
         qs = qs.filter(query_filters)
-        qs = qs.order_by(plotter.date_key)
+        qs = qs.order_by(self.plotter.date_key)
 
-        x, y = zip(*qs.values_list(plotter.date_key, y_key))
+        x, y = zip(*qs.values_list(self.plotter.date_key, y_key))
 
         return x, y
 
@@ -158,7 +161,7 @@ class MapPlot(PlotKind):
         cax0.tick_params(axis="x", length=1.5, direction="out", which="minor")
         cax0.grid(False)
 
-    def plot(self, plotter: "ModelFieldPlotter"):
+    def plot(self):
         from graphs.graphs import BASE_WORLD, OKLCH
         from matplotlib.colors import LinearSegmentedColormap, LogNorm
 
@@ -166,7 +169,7 @@ class MapPlot(PlotKind):
         self.draw_colorbar(fig)
         ax, cax, _ = fig.get_axes()
 
-        countries, count = self.get_data(plotter)
+        countries, count = self.get_data()
         df_counts = pd.DataFrame({"ISO_A2_EH": countries, "count": count})
         vmax = df_counts["count"].max()
         BASE_WORLD.merge(df_counts, left_on="ISO_A2_EH", right_on="ISO_A2_EH").plot(
@@ -188,16 +191,19 @@ class MapPlot(PlotKind):
 
         return fig
 
-    def get_data(self, plotter: "ModelFieldPlotter", **kwargs):
-        qs = plotter.get_queryset()
+    def get_data(self):
+        """
+        Return the a tuple of lists of countries and their counts.
+        """
+        qs = self.plotter.get_queryset()
         count_key = self.options.get("count_key", "id")
         group_by_country_count = (
-            qs.filter(Q(**{plotter.country_key + "__isnull": False}))
-            .values(plotter.country_key)
+            qs.filter(Q(**{self.plotter.country_key + "__isnull": False}))
+            .values(self.plotter.country_key)
             .annotate(count=Count(count_key))
         )
 
         countries, count = zip(
-            *group_by_country_count.values_list(plotter.country_key, "count")
+            *group_by_country_count.values_list(self.plotter.country_key, "count")
         )
         return countries, count
diff --git a/scipost_django/graphs/graphs/plotter.py b/scipost_django/graphs/graphs/plotter.py
index b8f992b1154ba51197587ab808ac5969a665e5de..d7d37e110e01cb62503a37a59b817d1fede670b7 100644
--- a/scipost_django/graphs/graphs/plotter.py
+++ b/scipost_django/graphs/graphs/plotter.py
@@ -79,7 +79,7 @@ class ModelFieldPlotter(ABC):
             ]
         )
 
-        fig = kind.plot(plotter=self)
+        fig = kind.plot()
         fig.suptitle(options.get("title", None))
 
         return fig
diff --git a/scipost_django/graphs/views.py b/scipost_django/graphs/views.py
index dd50620b67ab13e1c5961f1798fb6de3c06e7b4c..aa401b6e1060b83ff82a188f190c23e85e896690 100644
--- a/scipost_django/graphs/views.py
+++ b/scipost_django/graphs/views.py
@@ -50,7 +50,10 @@ class PlotView(View):
         cleaned_data = form.clean()
 
         self.plotter = form.model_field_select_form.plotter
-        self.kind = form.plot_kind_select_form.kind
+        self.kind = form.plot_kind_select_form.kind_class(
+            options=form.plot_kind_select_form.cleaned_data,
+            plotter=self.plotter,
+        )
 
         self.plot_options = {
             "plot_kind": {},