From f2fea4626a1b51ae4cc1892c1e260e95972a137d Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Wed, 26 Apr 2023 19:06:28 +0200 Subject: [PATCH 01/24] began to set PLN(Y,O,cov) instead of PLN() (as stat models) --- pyPLNmodels/models.py | 76 ++++++++++++++++++++++++++----------------- test.py | 15 ++++++--- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index f8cbdf9d..b3aceb30 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod import pickle import warnings import os +from functools import singledispatchmethod +from multipledispatch import dispatch import pandas as pd import torch @@ -68,17 +70,17 @@ class _PLN(ABC): _latent_var: torch.Tensor _latent_mean: torch.Tensor - def __init__(self): + def __init__(self, counts, covariates, offsets, offsets_formula): """ Simple initialization method. """ - self._fitted = False - self.plotargs = PLNPlotArgs(self.WINDOW) - def format_model_param(self, counts, covariates, offsets, offsets_formula): self._counts, self._covariates, self._offsets = format_model_param( counts, covariates, offsets, offsets_formula ) + check_data_shape(self._counts, self._covariates, self._offsets) + self._fitted = False + self.plotargs = PLNPlotArgs(self.WINDOW) @property def nb_iteration_done(self): @@ -140,16 +142,12 @@ class _PLN(ABC): def fit( self, - counts, - covariates=None, - offsets=None, nb_max_iteration=50000, lr=0.01, class_optimizer=torch.optim.Rprop, tol=1e-6, do_smart_init=True, verbose=False, - offsets_formula="logsum", keep_going=False, ): """ @@ -169,8 +167,6 @@ class _PLN(ABC): self.beginnning_time = time.time() if keep_going is False: - self.format_model_param(counts, covariates, offsets, offsets_formula) - check_data_shape(self._counts, self._covariates, self._offsets) self.init_parameters(do_smart_init) if self._fitted is True and keep_going is True: self.beginnning_time -= self.plotargs.running_times[-1] @@ -487,6 +483,20 @@ class PLN(_PLN): NAME = "PLN" coef: torch.Tensor + @singledispatchmethod + def __init__(self, counts, covariates=None, offsets=None, offsets_formula="logsum"): + super().__init__(counts, covariates, offsets, offsets_formula) + + @__init__.register(str) + @__init__.register(str) + def _(self, path_of_directory: str, other: str): + print("file") + + @__init__.register(pd.DataFrame) + @__init__.register(str) + def _(self, formula: str, data: pd.DataFrame): + print("formula") + @property def description(self): return "full covariance model." @@ -590,13 +600,30 @@ class PLN(_PLN): class PLNPCA: - def __init__(self, ranks): + def __init__( + self, + counts, + covariates=None, + offsets=None, + offsets_formula="logsum", + ranks=range(1, 5), + ): + self._counts, self._covariates, self._offsets = format_model_param( + counts, covariates, offsets, offsets_formula + ) + check_data_shape(self._counts, self._covariates, self._offsets) + self._fitted = False + self.init_models(ranks) + + def init_models(self, ranks): if isinstance(ranks, (list, np.ndarray)): self.ranks = ranks self.dict_models = {} for rank in ranks: if isinstance(rank, (int, np.int64)): - self.dict_models[rank] = _PLNPCA(rank) + self.dict_models[rank] = _PLNPCA( + rank, self._counts, self._covariates, self._offsets + ) else: raise TypeError( "Please instantiate with either a list\ @@ -626,34 +653,23 @@ class PLNPCA: ## only in PLNPCA, then we don't do it for each _PLNPCA but then PLN is not doing it. 
def fit( self, - counts, - covariates=None, - offsets=None, nb_max_iteration=100000, lr=0.01, class_optimizer=torch.optim.Rprop, tol=1e-6, do_smart_init=True, verbose=False, - offsets_formula="logsum", keep_going=False, ): self.print_beginning_message() - counts, _, offsets = format_model_param( - counts, covariates, offsets, offsets_formula - ) for pca in self.dict_models.values(): pca.fit( - counts, - covariates, - offsets, nb_max_iteration, lr, class_optimizer, tol, do_smart_init, verbose, - None, keep_going, ) self.print_ending_message() @@ -784,18 +800,20 @@ class _PLNPCA(_PLN): NAME = "PLNPCA" _components: torch.Tensor - def __init__(self, rank): - super().__init__() + def __init__(self, rank, counts, covariates, offsets): self._rank = rank + self._counts = counts + self._covariates = covariates + self._offsets = offsets - def init_parameters(self, do_smart_init): - if self.dim < self._rank: - warning_string = f"\nThe requested rank of approximation {self._rank} \ + if self.dim < self.rank: + warning_string = f"\nThe requested rank of approximation {self.rank} \ is greater than the number of variables {self.dim}. \ Setting rank to {self.dim}" warnings.warn(warning_string) self._rank = self.dim - super().init_parameters(do_smart_init) + self._fitted = False + self.plotargs = PLNPlotArgs(self.WINDOW) @property def model_path(self): diff --git a/test.py b/test.py index 900e42df..16c403ea 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,8 @@ from pyPLNmodels.models import PLNPCA, _PLNPCA, PLN from pyPLNmodels import get_real_count_data, get_simulated_count_data import os +import pandas as pd +import numpy as np os.chdir("./pyPLNmodels/") @@ -11,12 +13,17 @@ covariates = None offsets = None # counts, covariates, offsets = get_simulated_count_data(seed = 0) -pca = PLNPCA([3, 4]) +# pca = PLNPCA(counts, covariates, offsets,ranks = [3, 4]) -pca.fit(counts, covariates, offsets, tol=0.1) -print(pca) +# pca.fit() +# print(pca) -# pln = PLN() +# pln = PLN(counts, covariates, offsets) +# pln = PLN("test",4.) +a = pd.DataFrame(data=np.zeros((10, 10))) +pln = PLN("test", "test") +# pln.fit() +# print(pln) # pcamodel = pca.best_model() # pcamodel.save() # model = PLNPCA([4])[4] -- GitLab From 062971a29ac7371ece903995049cf6b00965d9f7 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Wed, 26 Apr 2023 19:47:12 +0200 Subject: [PATCH 02/24] write PLN(Y,O,cov) for example instead of PLN(). I overload __init__ to give a str and a pd.DataFrame. --- pyPLNmodels/models.py | 26 +++++++++++--------------- test.py | 6 ++---- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index b3aceb30..d3f8e1c7 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -70,7 +70,8 @@ class _PLN(ABC): _latent_var: torch.Tensor _latent_mean: torch.Tensor - def __init__(self, counts, covariates, offsets, offsets_formula): + @singledispatchmethod + def __init__(self, counts, covariates=None, offsets=None, offsets_formula="logsum"): """ Simple initialization method. 
""" @@ -82,6 +83,10 @@ class _PLN(ABC): self._fitted = False self.plotargs = PLNPlotArgs(self.WINDOW) + @__init__.register(str) + def _(self, formula: str, data: pd.DataFrame): + print("formula") + @property def nb_iteration_done(self): return len(self.plotargs.elbos_list) @@ -483,20 +488,6 @@ class PLN(_PLN): NAME = "PLN" coef: torch.Tensor - @singledispatchmethod - def __init__(self, counts, covariates=None, offsets=None, offsets_formula="logsum"): - super().__init__(counts, covariates, offsets, offsets_formula) - - @__init__.register(str) - @__init__.register(str) - def _(self, path_of_directory: str, other: str): - print("file") - - @__init__.register(pd.DataFrame) - @__init__.register(str) - def _(self, formula: str, data: pd.DataFrame): - print("formula") - @property def description(self): return "full covariance model." @@ -600,6 +591,7 @@ class PLN(_PLN): class PLNPCA: + @singledispatchmethod def __init__( self, counts, @@ -615,6 +607,10 @@ class PLNPCA: self._fitted = False self.init_models(ranks) + @__init__.register(str) + def _(self, formula: str, data: pd.DataFrame): + print("formula") + def init_models(self, ranks): if isinstance(ranks, (list, np.ndarray)): self.ranks = ranks diff --git a/test.py b/test.py index 16c403ea..2dc78415 100644 --- a/test.py +++ b/test.py @@ -13,15 +13,13 @@ covariates = None offsets = None # counts, covariates, offsets = get_simulated_count_data(seed = 0) -# pca = PLNPCA(counts, covariates, offsets,ranks = [3, 4]) +pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4]) +pca.fit() # pca.fit() # print(pca) # pln = PLN(counts, covariates, offsets) -# pln = PLN("test",4.) -a = pd.DataFrame(data=np.zeros((10, 10))) -pln = PLN("test", "test") # pln.fit() # print(pln) # pcamodel = pca.best_model() -- GitLab From 7a23ab495a5d2d743b8f128b31aaad35766ad53d Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Thu, 27 Apr 2023 13:24:54 +0200 Subject: [PATCH 03/24] began to add init with formula --- pyPLNmodels/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index d3f8e1c7..5c853290 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -12,6 +12,7 @@ import numpy as np import seaborn as sns import matplotlib.pyplot as plt from sklearn.decomposition import PCA +from patsy import dmatrices from ._closed_forms import ( @@ -85,7 +86,10 @@ class _PLN(ABC): @__init__.register(str) def _(self, formula: str, data: pd.DataFrame): - print("formula") + dmatrix = dmatrices(formula, data=data) + self._counts = dmatrix[0] + self._covariates = dmatrix[1] + offsets = None @property def nb_iteration_done(self): -- GitLab From 36f905979ba267c343ccce311a136b3886b1017f Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Thu, 27 Apr 2023 15:55:51 +0200 Subject: [PATCH 04/24] continue to implement formula like --- pyPLNmodels/models.py | 9 +++++---- test.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index 5c853290..68385eec 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -4,7 +4,6 @@ import pickle import warnings import os from functools import singledispatchmethod -from multipledispatch import dispatch import pandas as pd import torch @@ -83,13 +82,15 @@ class _PLN(ABC): check_data_shape(self._counts, self._covariates, self._offsets) self._fitted = False self.plotargs = PLNPlotArgs(self.WINDOW) + print("normal init") @__init__.register(str) - 
def _(self, formula: str, data: pd.DataFrame): + def _(self, formula: str, data: pd.DataFrame, offsets_formula="logsum"): dmatrix = dmatrices(formula, data=data) - self._counts = dmatrix[0] - self._covariates = dmatrix[1] + counts = dmatrix[0] + covariates = dmatrix[1] offsets = None + super().__init__(counts, covariates, offsets, offsets_formula) @property def nb_iteration_done(self): diff --git a/test.py b/test.py index 2dc78415..ecc3d983 100644 --- a/test.py +++ b/test.py @@ -13,13 +13,14 @@ covariates = None offsets = None # counts, covariates, offsets = get_simulated_count_data(seed = 0) -pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4]) -pca.fit() +# pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4]) +# pca.fit(tol = 0.1) # pca.fit() # print(pca) - -# pln = PLN(counts, covariates, offsets) +data = pd.DataFrame(counts) +print("data :", data) +# pln = PLN("counts~1", data) # pln.fit() # print(pln) # pcamodel = pca.best_model() -- GitLab From 3f97ec5aa7fcea468206122279183b4695326b2e Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 15:31:00 +0200 Subject: [PATCH 05/24] add a poisson regressor instead of a regression on the logarithm. --- pyPLNmodels/_utils.py | 158 ++++++++++++++++++++++++------------------ 1 file changed, 90 insertions(+), 68 deletions(-) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 5e2d73bb..c8d039c6 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -327,6 +327,18 @@ def check_two_dimensions_are_equal( ) +def init_S(counts, covariates, offsets, beta, C, M): + n, rank = M.shape + batch_matrix = torch.matmul(C.unsqueeze(2), C.unsqueeze(1)).unsqueeze(0) + CW = torch.matmul(C.unsqueeze(0), M.unsqueeze(2)).squeeze() + common = torch.exp(offsets + covariates @ beta + CW).unsqueeze(2).unsqueeze(3) + prod = batch_matrix * common + hess_posterior = torch.sum(prod, axis=1) + torch.eye(rank).to(DEVICE) + inv_hess_posterior = -torch.inverse(hess_posterior) + hess_posterior = torch.diagonal(inv_hess_posterior, dim1=-2, dim2=-1) + return hess_posterior + + def format_data(data): if isinstance(data, pd.DataFrame): return torch.from_numpy(data.values).double().to(DEVICE) @@ -406,87 +418,97 @@ def plot_ellipse(mean_x, mean_y, cov, ax): return pearson -def get_components_simulation(dim, rank): - block_size = dim // rank - prev_state = torch.random.get_rng_state() - torch.random.manual_seed(0) - components = torch.zeros(dim, rank) - for column_number in range(rank): - components[ - column_number * block_size : (column_number + 1) * block_size, column_number - ] = 1 - components += torch.randn(dim, rank) / 8 - torch.random.set_rng_state(prev_state) - return components.to(DEVICE) - - -def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim): - prev_state = torch.random.get_rng_state() - torch.random.manual_seed(0) - if nb_cov < 2: - covariates = None - else: - covariates = torch.randint( - low=-1, - high=2, - size=(n_samples, nb_cov - 1), - dtype=torch.float64, - device=DEVICE, - ) - coef = torch.randn(nb_cov, dim, device=DEVICE) - offsets = torch.randint( - low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device=DEVICE - ) - torch.random.set_rng_state(prev_state) - return offsets, covariates, coef - - -def get_simulated_count_data( - n_samples=100, dim=25, rank=5, nb_cov=1, return_true_param=False, seed=0 -): - components = get_components_simulation(dim, rank) - offsets, cov, true_coef = get_simulation_offsets_cov_coef(n_samples, nb_cov, dim) - true_covariance = 
torch.matmul(components, components.T) - counts, _, _ = sample_pln(components, true_coef, cov, offsets, seed=seed) +def get_simulated_count_data(n=100, p=25, rank=25, d=1, return_true_param=False): + true_beta = torch.randn(d + 1, p, device=DEVICE) + C = torch.randn(p, rank, device=DEVICE) / 5 + O = torch.ones((n, p), device=DEVICE) / 2 + covariates = torch.randn((n, d), device=DEVICE) + true_Sigma = torch.matmul(C, C.T) + Y, _, _ = sample_PLN(C, true_beta, covariates, O) if return_true_param is True: - return counts, cov, offsets, true_covariance, true_coef - return counts, cov, offsets + return Y, covariates, O, true_Sigma, true_beta + return Y, covariates, O -def get_real_count_data(n_samples=270, dim=100): - if n_samples > 297: +def get_real_count_data(n=270, p=100): + if n > 297: warnings.warn( - f"\nTaking the whole 270 samples of the dataset. Requested:n_samples={n_samples}, returned:270" + f"\nTaking the whole 270 samples of the dataset. Requested:n={n}, returned:270" ) - n_samples = 270 - if dim > 100: + n = 270 + if p > 100: warnings.warn( - f"\nTaking the whole 100 variables. Requested:dim={dim}, returned:100" + f"\nTaking the whole 100 variables. Requested:p={p}, returned:100" ) dim = 100 - counts = pd.read_csv("../example_data/real_data/Y_mark.csv").values[ - :n_samples, :dim - ] - print(f"Returning dataset of size {counts.shape}") - return counts + Y = pd.read_csv("../example_data/real_data/Y_mark.csv").values[:n, :p] + print(f"Returning dataset of size {Y.shape}") + return Y -def closest(lst, element): +def closest(lst, K): lst = np.asarray(lst) - idx = (np.abs(lst - element)).argmin() + idx = (np.abs(lst - K)).argmin() return lst[idx] -def check_dimensions_are_equal(tens1, tens2): - if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]: - raise ValueError("Tensors should have the same size.") +class poissonReg: + """Poisson regressor class.""" + + def __init__(self): + """No particular initialization is needed.""" + pass + def fit(self, Y, O, covariates, Niter_max=300, tol=0.001, lr=0.005, verbose=False): + """Run a gradient ascent to maximize the log likelihood, using + pytorch autodifferentiation. The log likelihood considered is + the one from a poisson regression model. It is roughly the + same as PLN without the latent layer Z. -def to_tensor(obj): - if isinstance(obj, np.ndarray): - return torch.from_numpy(obj) - if isinstance(obj, torch.Tensor): - return obj - if isinstance(obj, pd.DataFrame): - return torch.from_numpy(obj.values) - raise TypeError("Please give either a nd.array or torch.Tensor or pd.DataFrame") + Args: + Y: torch.tensor. Counts with size (n,p) + 0: torch.tensor. Offset, size (n,p) + covariates: torch.tensor. Covariates, size (n,d) + Niter_max: int, optional. The maximum number of iteration. + Default is 300. + tol: non negative float, optional. The tolerance criteria. + Will stop if the norm of the gradient is less than + or equal to this threshold. Default is 0.001. + lr: positive float, optional. Learning rate for the gradient ascent. + Default is 0.005. + verbose: bool, optional. If True, will print some stats. + + Returns : None. Update the parameter beta. You can access it + by calling self.beta. 
+ """ + # Initialization of beta of size (d,p) + beta = torch.rand( + (covariates.shape[1], Y.shape[1]), device=DEVICE, requires_grad=True + ) + optimizer = torch.optim.Rprop([beta], lr=lr) + i = 0 + grad_norm = 2 * tol # Criterion + while i < Niter_max and grad_norm > tol: + loss = -compute_poissreg_log_like(Y, O, covariates, beta) + loss.backward() + optimizer.step() + grad_norm = torch.norm(beta.grad) + beta.grad.zero_() + i += 1 + if verbose: + if i % 10 == 0: + print("log like : ", -loss) + print("grad_norm : ", grad_norm) + if i < Niter_max: + print("Tolerance reached in {} iterations".format(i)) + else: + print("Maxium number of iterations reached") + self.beta = beta + + +def compute_poissreg_log_like(Y, O, covariates, beta): + """Compute the log likelihood of a Poisson regression.""" + # Matrix multiplication of X and beta. + XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze() + # Returns the formula of the log likelihood of a poisson regression model. + return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB)) -- GitLab From 428db818e16b5fd28999a9026be5189c1a77fa5f Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 15:41:24 +0200 Subject: [PATCH 06/24] implement the right intialization for C and Sigma. --- pyPLNmodels/_utils.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index c8d039c6..6b008f14 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -87,7 +87,7 @@ class PLNPlotArgs: plt.savefig(name_doss) -def init_sigma(counts, covariates, coef): +def init_covariance(counts, covariates, coef): """Initialization for covariance for the PLN model. Take the log of counts (careful when counts=0), remove the covariates effects X@coef and then do as a MLE for Gaussians samples. @@ -99,9 +99,7 @@ def init_sigma(counts, covariates, coef): Returns : torch.tensor of size (p,p). """ log_y = torch.log(counts + (counts == 0) * math.exp(-2)) - log_y_centered = ( - log_y - torch.matmul(covariates.unsqueeze(1), coef.unsqueeze(0)).squeeze() - ) + log_y_centered = log_y - torch.mean(log_y, axis=0) # MLE in a Gaussian setting n_samples = counts.shape[0] sigma_hat = 1 / (n_samples - 1) * (log_y_centered.T) @ log_y_centered @@ -121,7 +119,7 @@ def init_components(counts, covariates, coef, rank): Returns : torch.tensor of size (p,rank). The initialization of components. """ - sigma_hat = init_sigma(counts, covariates, coef).detach() + sigma_hat = init_covariance(counts, covariates, coef).detach() components = components_from_covariance(sigma_hat, rank) return components @@ -241,13 +239,8 @@ def components_from_covariance(covariance, rank): return requested_components -def init_coef(counts, covariates): - log_y = torch.log(counts + (counts == 0) * math.exp(-2)) - log_y = log_y.to(DEVICE) - return torch.matmul( - torch.inverse(torch.matmul(covariates.T, covariates)), - torch.matmul(covariates.T, log_y), - ) +def init_coef(counts, covariates, offsets): + poiss_reg = PoissonReg() def log_stirling(integer): @@ -452,7 +445,7 @@ def closest(lst, K): return lst[idx] -class poissonReg: +class PoissonReg: """Poisson regressor class.""" def __init__(self): -- GitLab From 3e4aac90174c2b55e494584e82e8520977c9bce2 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 15:47:58 +0200 Subject: [PATCH 07/24] to_tensor function retrived back from older commits. 
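A short illustration of what the last three patches set up (this sketch is not part of the commits): the regression coefficients are now initialized by a torch-based Poisson regression instead of a linear regression on the log-counts. The toy data, the import of DEVICE and the exact argument order are assumptions read off the diffs above; note that a later patch in this series swaps the offsets and covariates arguments of fit().

    import torch
    from pyPLNmodels._utils import PoissonReg, DEVICE  # DEVICE assumed to be defined in _utils

    # toy problem: n samples, p count variables, d covariates
    n, p, d = 50, 10, 2
    Y = torch.poisson(3 * torch.ones(n, p)).to(DEVICE)   # counts, shape (n, p)
    O = torch.zeros(n, p, device=DEVICE)                 # offsets on the log scale
    covariates = torch.randn(n, d, device=DEVICE)        # design matrix, shape (n, d)

    reg = PoissonReg()
    reg.fit(Y, O, covariates, Niter_max=300, tol=1e-3, lr=0.005, verbose=True)
    print(reg.beta.shape)  # (d, p) coefficients, later wired into init_coef to seed the PLN fit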
--- pyPLNmodels/_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 6b008f14..01b68846 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -505,3 +505,13 @@ def compute_poissreg_log_like(Y, O, covariates, beta): XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze() # Returns the formula of the log likelihood of a poisson regression model. return torch.sum(-torch.exp(O + XB) + torch.multiply(Y, O + XB)) + + +def to_tensor(obj): + if isinstance(obj, np.ndarray): + return torch.from_numpy(obj) + if isinstance(obj, torch.Tensor): + return obj + if isinstance(obj, pd.DataFrame): + return torch.from_numpy(obj.values) + raise TypeError("Please give either a nd.array or torch.Tensor or pd.DataFrame") -- GitLab From e8b8e146ffd12aba9a7f3139ac0f30ff0f4a5661 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 15:55:06 +0200 Subject: [PATCH 08/24] write the right portions of file --- pyPLNmodels/_utils.py | 77 ++++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 20 deletions(-) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 01b68846..326d5279 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -411,37 +411,74 @@ def plot_ellipse(mean_x, mean_y, cov, ax): return pearson -def get_simulated_count_data(n=100, p=25, rank=25, d=1, return_true_param=False): - true_beta = torch.randn(d + 1, p, device=DEVICE) - C = torch.randn(p, rank, device=DEVICE) / 5 - O = torch.ones((n, p), device=DEVICE) / 2 - covariates = torch.randn((n, d), device=DEVICE) - true_Sigma = torch.matmul(C, C.T) - Y, _, _ = sample_PLN(C, true_beta, covariates, O) +def get_components_simulation(dim, rank): + block_size = dim // rank + prev_state = torch.random.get_rng_state() + torch.random.manual_seed(0) + components = torch.zeros(dim, rank) + for column_number in range(rank): + components[ + column_number * block_size : (column_number + 1) * block_size, column_number + ] = 1 + components += torch.randn(dim, rank) / 8 + torch.random.set_rng_state(prev_state) + return components.to(DEVICE) + + +def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim): + prev_state = torch.random.get_rng_state() + torch.random.manual_seed(0) + if nb_cov < 2: + covariates = None + else: + covariates = torch.randint( + low=-1, + high=2, + size=(n_samples, nb_cov - 1), + dtype=torch.float64, + device=DEVICE, + ) + coef = torch.randn(nb_cov, dim, device=DEVICE) + offsets = torch.randint( + low=0, high=2, size=(n_samples, dim), dtype=torch.float64, device=DEVICE + ) + torch.random.set_rng_state(prev_state) + return offsets, covariates, coef + + +def get_simulated_count_data( + n_samples=100, dim=25, rank=5, nb_cov=1, return_true_param=False, seed=0 +): + components = get_components_simulation(dim, rank) + offsets, cov, true_coef = get_simulation_offsets_cov_coef(n_samples, nb_cov, dim) + true_covariance = torch.matmul(components, components.T) + counts, _, _ = sample_pln(components, true_coef, cov, offsets, seed=seed) if return_true_param is True: - return Y, covariates, O, true_Sigma, true_beta - return Y, covariates, O + return counts, cov, offsets, true_covariance, true_coef + return counts, cov, offsets -def get_real_count_data(n=270, p=100): - if n > 297: +def get_real_count_data(n_samples=270, dim=100): + if n_samples > 297: warnings.warn( - f"\nTaking the whole 270 samples of the dataset. 
Requested:n={n}, returned:270" + f"\nTaking the whole 270 samples of the dataset. Requested:n_samples={n_samples}, returned:270" ) - n = 270 - if p > 100: + n_samples = 270 + if dim > 100: warnings.warn( - f"\nTaking the whole 100 variables. Requested:p={p}, returned:100" + f"\nTaking the whole 100 variables. Requested:dim={dim}, returned:100" ) dim = 100 - Y = pd.read_csv("../example_data/real_data/Y_mark.csv").values[:n, :p] - print(f"Returning dataset of size {Y.shape}") - return Y + counts = pd.read_csv("../example_data/real_data/Y_mark.csv").values[ + :n_samples, :dim + ] + print(f"Returning dataset of size {counts.shape}") + return counts -def closest(lst, K): +def closest(lst, element): lst = np.asarray(lst) - idx = (np.abs(lst - K)).argmin() + idx = (np.abs(lst - element)).argmin() return lst[idx] -- GitLab From e593f8e5166ba0e4548ddaf1c4b0f90f0ee59457 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 15:56:33 +0200 Subject: [PATCH 09/24] add check_dimensions functions from previous commit --- pyPLNmodels/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 326d5279..636cf2ed 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -552,3 +552,8 @@ def to_tensor(obj): if isinstance(obj, pd.DataFrame): return torch.from_numpy(obj.values) raise TypeError("Please give either a nd.array or torch.Tensor or pd.DataFrame") + + +def check_dimensions_are_equal(tens1, tens2): + if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]: + raise ValueError("Tensors should have the same size.") -- GitLab From 845da678e6aad76bad73c4464568fe39da4dc6b6 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 16:01:47 +0200 Subject: [PATCH 10/24] doing the right initialization. Need to check for the right tolerance now. --- pyPLNmodels/_utils.py | 4 +++- pyPLNmodels/models.py | 10 +++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 636cf2ed..5de18b2c 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -241,6 +241,8 @@ def components_from_covariance(covariance, rank): def init_coef(counts, covariates, offsets): poiss_reg = PoissonReg() + poiss_reg.fit(counts, covariates, offsets) + return poiss_reg.beta def log_stirling(integer): @@ -489,7 +491,7 @@ class PoissonReg: """No particular initialization is needed.""" pass - def fit(self, Y, O, covariates, Niter_max=300, tol=0.001, lr=0.005, verbose=False): + def fit(self, Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False): """Run a gradient ascent to maximize the log likelihood, using pytorch autodifferentiation. The log likelihood considered is the one from a poisson regression model. 
It is roughly the diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index 68385eec..3fa83b1e 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -22,7 +22,7 @@ from ._closed_forms import ( from .elbos import elbo_plnpca, elbo_zi_pln, profiled_elbo_pln from ._utils import ( PLNPlotArgs, - init_sigma, + init_covariance, init_components, init_coef, check_two_dimensions_are_equal, @@ -109,7 +109,7 @@ class _PLN(ABC): return self.covariates.shape[1] def smart_init_coef(self): - self._coef = init_coef(self._counts, self._covariates) + self._coef = init_coef(self._counts, self._covariates, self._offsets) def random_init_coef(self): self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE) @@ -221,8 +221,8 @@ class _PLN(ABC): def print_end_of_fitting_message(self, stop_condition, tol): if stop_condition is True: print( - f"Tolerance {tol} reached" - f"n {self.plotargs.iteration_number} iterations" + f"Tolerance {tol} reached " + f"in {self.plotargs.iteration_number} iterations" ) else: print( @@ -959,7 +959,7 @@ class ZIPLN(PLN): # should change the good initialization, especially for _coef_inflation def smart_init_model_parameters(self): super().smart_init_model_parameters() - self._covariance = init_sigma( + self._covariance = init_covariance( self._counts, self._covariates, self._offsets, self._coef ) self._coef_inflation = torch.randn(self.nb_cov, self.dim) -- GitLab From 87d2dbea2e35de84b19470851bb43221c9e248a3 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 28 Apr 2023 16:02:37 +0200 Subject: [PATCH 11/24] minor changes in the test file. --- test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test.py b/test.py index ecc3d983..b9d5e686 100644 --- a/test.py +++ b/test.py @@ -13,13 +13,12 @@ covariates = None offsets = None # counts, covariates, offsets = get_simulated_count_data(seed = 0) -# pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4]) -# pca.fit(tol = 0.1) +pca = PLNPCA(counts, covariates, offsets, ranks=[3, 4]) +pca.fit(tol=0.1) # pca.fit() # print(pca) -data = pd.DataFrame(counts) -print("data :", data) +# data = pd.DataFrame(counts) # pln = PLN("counts~1", data) # pln.fit() # print(pln) -- GitLab From 8fcc6cce669b56115b15ca48c22f2ace056805f9 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Tue, 2 May 2023 09:37:18 +0200 Subject: [PATCH 12/24] add formula init and can also save parameters in order to use them in another init --- pyPLNmodels/_utils.py | 105 ++++++++++++++--- pyPLNmodels/models.py | 269 ++++++++++++++++++++++++++++-------------- 2 files changed, 269 insertions(+), 105 deletions(-) diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 5de18b2c..4d97045b 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -1,5 +1,6 @@ import math # pylint:disable=[C0114] import warnings +import os import matplotlib.pyplot as plt import numpy as np @@ -8,6 +9,7 @@ import torch.linalg as TLA import pandas as pd from matplotlib.patches import Ellipse from matplotlib import transforms +from patsy import dmatrices torch.set_default_dtype(torch.float64) @@ -192,15 +194,11 @@ def sample_pln(components, coef, covariates, offsets, _coef_inflation=None, seed torch.random.manual_seed(seed) n_samples = offsets.shape[0] rank = components.shape[1] - full_of_ones = torch.ones((n_samples, 1)) if covariates is None: - covariates = full_of_ones + XB = 0 else: - covariates = torch.stack((full_of_ones, covariates), axis=1).squeeze() - gaussian = ( - 
torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) - + covariates @ coef - ) + XB = covariates @ coef + gaussian = torch.mm(torch.randn(n_samples, rank, device=DEVICE), components.T) + XB parameter = torch.exp(offsets + gaussian) if _coef_inflation is not None: print("ZIPLN is sampled") @@ -348,7 +346,10 @@ def format_data(data): def format_model_param(counts, covariates, offsets, offsets_formula): counts = format_data(counts) - covariates = prepare_covariates(covariates, counts.shape[0]) + if covariates is None: + covariates = torch.zeros(counts.shape[0]) + else: + covariates = format_data(covariates) if offsets is None: if offsets_formula == "logsum": print("Setting the offsets as the log of the sum of counts") @@ -362,12 +363,16 @@ def format_model_param(counts, covariates, offsets, offsets_formula): return counts, covariates, offsets -def prepare_covariates(covariates, n_samples): - full_of_ones = torch.full((n_samples, 1), 1, device=DEVICE).double() - if covariates is None: - return full_of_ones +def remove_useless_intercepts(covariates): covariates = format_data(covariates) - return torch.stack((full_of_ones, covariates), axis=1).squeeze() + if covariates.shape[1] < 2: + return covariates + first_column = covariates[:, 0] + second_column = covariates[:, 1] + diff = first_column - second_column + if torch.sum(torch.abs(diff - diff[0])) == 0: + return covariates[:, 1:] + return covariates def check_data_shape(counts, covariates, offsets): @@ -430,13 +435,13 @@ def get_components_simulation(dim, rank): def get_simulation_offsets_cov_coef(n_samples, nb_cov, dim): prev_state = torch.random.get_rng_state() torch.random.manual_seed(0) - if nb_cov < 2: + if nb_cov == 0: covariates = None else: covariates = torch.randint( low=-1, high=2, - size=(n_samples, nb_cov - 1), + size=(n_samples, nb_cov), dtype=torch.float64, device=DEVICE, ) @@ -559,3 +564,73 @@ def to_tensor(obj): def check_dimensions_are_equal(tens1, tens2): if tens1.shape[0] != tens2.shape[0] or tens1.shape[1] != tens2.shape[1]: raise ValueError("Tensors should have the same size.") + + +def load_model(path_of_directory): + os.chdir(path_of_directory) + all_files = os.listdir() + data = {} + for filename in all_files: + if len(filename) > 4: + if filename[-4:] == ".csv": + parameter = filename[:-4] + # data[parameter] = pd.read_csv(filename, header=None).values + try: + data[parameter] = pd.read_csv(filename, header=None).values + except pd.errors.EmptyDataError as err: + print( + f"Can t load {parameter} since empty. Standard initialization will be performed" + ) + os.chdir("../") + return data + + +def load_plnpca(path_of_directory, ranks=None): + os.chdir(path_of_directory) + if ranks is None: + dirnames = os.listdir() + ranks = [] + for dirname in dirnames: + try: + rank = int(dirname[-1]) + except ValueError: + print( + f"Can t load the model {dirname}. End of {dirname} should be an int" + ) + ranks.append(rank) + datas = {} + for rank in ranks: + datas[rank] = load_model(f"PLNPCA_rank_{rank}") + os.chdir("../") + return datas + + +def check_right_rank(data, rank): + data_rank = data["latent_mean"].shape[1] + if data_rank != rank: + raise RuntimeError( + f"Wrong rank during initialization." + f" Got rank {rank} and data with rank {data_rank}." 
+ ) + + +def extract_data_from_formula(formula, data): + dmatrix = dmatrices(formula, data=data) + counts = dmatrix[0] + covariates = dmatrix[1] + if len(covariates) > 0: + covariates = remove_useless_intercepts(covariates) + offsets = data.get("offsets", None) + return counts, covariates, offsets + + +def is_dict_of_dict(dictionnary): + if isinstance(dictionnary[list(dictionnary.keys())[0]], dict): + return True + return False + + +def get_dict_initialization(rank, dict_of_dict): + if dict_of_dict is None: + return None + return dict_of_dict[rank] diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index 3fa83b1e..e09c25df 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -4,6 +4,7 @@ import pickle import warnings import os from functools import singledispatchmethod +from collections.abc import Iterable import pandas as pd import torch @@ -33,9 +34,13 @@ from ._utils import ( nice_string_of_dict, plot_ellipse, closest, - prepare_covariates, to_tensor, check_dimensions_are_equal, + check_right_rank, + remove_useless_intercepts, + is_dict_of_dict, + extract_data_from_formula, + get_dict_initialization, ) if torch.cuda.is_available(): @@ -71,7 +76,14 @@ class _PLN(ABC): _latent_mean: torch.Tensor @singledispatchmethod - def __init__(self, counts, covariates=None, offsets=None, offsets_formula="logsum"): + def __init__( + self, + counts, + covariates=None, + offsets=None, + offsets_formula="logsum", + dict_initialization=None, + ): """ Simple initialization method. """ @@ -82,15 +94,30 @@ class _PLN(ABC): check_data_shape(self._counts, self._covariates, self._offsets) self._fitted = False self.plotargs = PLNPlotArgs(self.WINDOW) - print("normal init") + if dict_initialization is not None: + for key, value in dict_initialization.items(): + value = torch.from_numpy(dict_initialization[key]) + setattr(self, key, value) @__init__.register(str) - def _(self, formula: str, data: pd.DataFrame, offsets_formula="logsum"): + def _( + self, + formula: str, + data: dict, + offsets_formula="logsum", + dict_initialization=None, + ): dmatrix = dmatrices(formula, data=data) counts = dmatrix[0] covariates = dmatrix[1] - offsets = None - super().__init__(counts, covariates, offsets, offsets_formula) + if len(covariates) > 0: + covariates = remove_useless_intercepts(covariates) + offsets = data.get("offsets", None) + self.__init__(counts, covariates, offsets, offsets_formula, dict_initialization) + + @property + def fitted(self): + return @property def nb_iteration_done(self): @@ -155,10 +182,9 @@ class _PLN(ABC): nb_max_iteration=50000, lr=0.01, class_optimizer=torch.optim.Rprop, - tol=1e-6, + tol=1e-5, do_smart_init=True, verbose=False, - keep_going=False, ): """ Main function of the class. Fit a PLN to the data. 
@@ -176,9 +202,9 @@ class _PLN(ABC): self.print_beginning_message() self.beginnning_time = time.time() - if keep_going is False: + if self._fitted is False: self.init_parameters(do_smart_init) - if self._fitted is True and keep_going is True: + else: self.beginnning_time -= self.plotargs.running_times[-1] self.optim = class_optimizer(self.list_of_parameters_needing_gradient, lr=lr) stop_condition = False @@ -366,7 +392,11 @@ class _PLN(ABC): @property def model_in_a_dict(self): - return self.dict_data | self.model_parameters | self.latent_parameters + return self.dict_data | self.dict_parameters + + @property + def dict_parameters(self): + return self.model_parameters | self.latent_parameters @property def coef(self): @@ -399,7 +429,7 @@ class _PLN(ABC): def save(self, path_of_directory="./"): path = f"{path_of_directory}/{self.model_path}/" os.makedirs(path, exist_ok=True) - for key, value in self.model_in_a_dict.items(): + for key, value in self.dict_parameters.items(): filename = f"{path}/{key}.csv" if isinstance(value, torch.Tensor): pd.DataFrame(np.array(value.cpu().detach())).to_csv( @@ -409,16 +439,6 @@ class _PLN(ABC): pd.DataFrame(np.array([value])).to_csv( filename, header=None, index=None ) - self._fitted = True - - def load(self, path_of_directory="./"): - path = f"{path_of_directory}/{self.model_path}/" - for key, value in self.model_in_a_dict.items(): - value = torch.from_numpy( - pd.read_csv(path + key + ".csv", header=None).values - ) - setattr(self, key, value) - self.put_parameters_to_device() @property def counts(self): @@ -479,13 +499,17 @@ class _PLN(ABC): return self.covariance def predict(self, covariates=None): - if isinstance(covariates, torch.Tensor): - if covariates.shape[-1] != self.nb_cov - 1: - error_string = f"X has wrong shape ({covariates.shape}).Should" - error_string += f" be ({self.n_samples, self.nb_cov-1})." - raise RuntimeError(error_string) - covariates_with_ones = prepare_covariates(covariates, self.n_samples) - return covariates_with_ones @ self.coef + if covariates is None: + return self.coef[0, :] + if covariates.shape[-1] != self.nb_cov: + error_string = f"X has wrong shape ({covariates.shape}).Should" + error_string += f" be ({self.n_samples, self.nb_cov})." 
+ raise RuntimeError(error_string) + return covariates @ self.coef + + @property + def model_path(self): + return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" # need to do a good init for M and S @@ -511,12 +535,10 @@ class PLN(_PLN): self.random_init_latent_parameters() def random_init_latent_parameters(self): - self._latent_var = 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) - self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE) - - @property - def model_path(self): - return self.NAME + if not hasattr(self, "_latent_var"): + self._latent_var = 1 / 2 * torch.ones((self.n_samples, self.dim)).to(DEVICE) + if not hasattr(self, "_latent_mean"): + self._latent_mean = torch.ones((self.n_samples, self.dim)).to(DEVICE) @property def list_of_parameters_needing_gradient(self): @@ -595,7 +617,10 @@ class PLN(_PLN): pass +## en train d'essayer de faire une seule init pour_PLNPCA class PLNPCA: + NAME = "PLNPCA" + @singledispatchmethod def __init__( self, @@ -604,44 +629,84 @@ class PLNPCA: offsets=None, offsets_formula="logsum", ranks=range(1, 5), + dict_of_dict_initialization=None, ): + self.init_data(counts, covariates, offsets, offsets_formula) + self.init_models(ranks, dict_of_dict_initialization) + + def init_data(self, counts, covariates, offsets, offsets_formula): self._counts, self._covariates, self._offsets = format_model_param( counts, covariates, offsets, offsets_formula ) check_data_shape(self._counts, self._covariates, self._offsets) self._fitted = False - self.init_models(ranks) @__init__.register(str) - def _(self, formula: str, data: pd.DataFrame): - print("formula") + def _( + self, + formula: str, + data: dict, + offsets_formula="logsum", + ranks=range(1, 5), + dict_of_dict_initialization=None, + ): + counts, covariates, offsets = extract_data_from_formula(formula, data) + self.__init__( + counts, + covariates, + offsets, + offsets_formula, + ranks, + dict_of_dict_initialization, + ) - def init_models(self, ranks): - if isinstance(ranks, (list, np.ndarray)): - self.ranks = ranks - self.dict_models = {} + def init_models(self, ranks, dict_of_dict_initialization): + if isinstance(ranks, (Iterable, np.ndarray)): + self.models = [] for rank in ranks: - if isinstance(rank, (int, np.int64)): - self.dict_models[rank] = _PLNPCA( - rank, self._counts, self._covariates, self._offsets + if isinstance(rank, (int, np.integer)): + dict_initialization = get_dict_initialization( + rank, dict_of_dict_initialization + ) + self.models.append( + _PLNPCA( + self._counts, + self._covariates, + self._offsets, + rank, + dict_initialization, + ) ) else: raise TypeError( - "Please instantiate with either a list\ - of integers or an integer." + f"Please instantiate with either a list " + f"of integers or an integer." ) - elif isinstance(ranks, int): - self.ranks = [ranks] - self.dict_models = {ranks: _PLNPCA(ranks)} + elif isinstance(ranks, (int, np.integer)): + dict_initialization = get_dict_initialization( + ranks, dict_of_dict_initialization + ) + self.models = [ + _PLNPCA( + self._counts, + self._covariates, + self._offsets, + rank, + dict_initialization, + ) + ] else: raise TypeError( - "Please instantiate with either a list of \ - integers or an integer." + f"Please instantiate with either a list " f"of integers or an integer." 
) @property - def models(self): - return list(self.dict_models.values()) + def ranks(self): + return [model.rank for model in self.models] + + @property + def dict_models(self): + return {model.rank: model for model in self.models} def print_beginning_message(self): return f"Adjusting {len(self.ranks)} PLN models for PCA analysis \n" @@ -650,6 +715,10 @@ class PLNPCA: def dim(self): return self[self.ranks[0]].dim + @property + def nb_cov(self): + return self[self.ranks[0]].nb_cov + ## should do something for this weird init. pb: if doing the init of self._counts etc ## only in PLNPCA, then we don't do it for each _PLNPCA but then PLN is not doing it. def fit( @@ -657,10 +726,9 @@ class PLNPCA: nb_max_iteration=100000, lr=0.01, class_optimizer=torch.optim.Rprop, - tol=1e-6, + tol=1e-3, do_smart_init=True, verbose=False, - keep_going=False, ): self.print_beginning_message() for pca in self.dict_models.values(): @@ -671,7 +739,6 @@ class PLNPCA: tol, do_smart_init, verbose, - keep_going, ) self.print_ending_message() @@ -747,13 +814,16 @@ class PLNPCA: return self[self.best_AIC_model_rank] raise ValueError(f"Unknown criterion {criterion}") - def save(self, path_of_directory="./"): + def save(self, path_of_directory="./", ranks=None): + if ranks is None: + ranks = self.ranks for model in self.models: - model.save(path_of_directory) + if model.rank in ranks: + model.save(f"{path_of_directory}/{self.model_path}") - def load(self, path_of_directory="./"): - for model in self.models: - model.load(path_of_directory) + @property + def model_path(self): + return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" @property def n_samples(self): @@ -797,28 +867,41 @@ class PLNPCA: return ".BIC, .AIC, .loglikes" +# Here, setting the value for each key in dict_parameters class _PLNPCA(_PLN): NAME = "PLNPCA" _components: torch.Tensor - def __init__(self, rank, counts, covariates, offsets): + @singledispatchmethod + def __init__(self, counts, covariates, offsets, rank, dict_initialization=None): self._rank = rank self._counts = counts self._covariates = covariates self._offsets = offsets + self.check_if_rank_is_too_high() + if dict_initialization is not None: + self.set_init_parameters(dict_initialization) + self._fitted = False + self.plotargs = PLNPlotArgs(self.WINDOW) + + def set_init_parameters(self, dict_parameters): + for key, array in dict_parameters.items(): + array = format_data(array) + setattr(self, key, array) + def check_if_rank_is_too_high(self): if self.dim < self.rank: - warning_string = f"\nThe requested rank of approximation {self.rank} \ - is greater than the number of variables {self.dim}. \ - Setting rank to {self.dim}" + warning_string = ( + f"\nThe requested rank of approximation {self.rank} " + f"is greater than the number of variables {self.dim}. 
" + f"Setting rank to {self.dim}" + ) warnings.warn(warning_string) self._rank = self.dim - self._fitted = False - self.plotargs = PLNPlotArgs(self.WINDOW) @property def model_path(self): - return f"{self.NAME}_{self._rank}_rank" + return f"{super().model_path}_rank_{self._rank}" @property def rank(self): @@ -836,10 +919,12 @@ class _PLNPCA(_PLN): return {"coef": self.coef, "components": self.components} def smart_init_model_parameters(self): - super().smart_init_coef() - self._components = init_components( - self._counts, self._covariates, self._coef, self._rank - ) + if not hasattr(self, "_coef"): + super().smart_init_coef() + if not hasattr(self, "_components"): + self._components = init_components( + self._counts, self._covariates, self._coef, self._rank + ) def random_init_model_parameters(self): super().random_init_coef() @@ -850,20 +935,22 @@ class _PLNPCA(_PLN): self._latent_mean = torch.ones((self.n_samples, self._rank)).to(DEVICE) def smart_init_latent_parameters(self): - self._latent_mean = ( - init_latent_mean( - self._counts, - self._covariates, - self._offsets, - self._coef, - self._components, + if not hasattr(self, "_latent_mean"): + self._latent_mean = ( + init_latent_mean( + self._counts, + self._covariates, + self._offsets, + self._coef, + self._components, + ) + .to(DEVICE) + .detach() + ) + if not hasattr(self, "_latent_var"): + self._latent_var = ( + 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) ) - .to(DEVICE) - .detach() - ) - self._latent_var = 1 / 2 * torch.ones((self.n_samples, self._rank)).to(DEVICE) - self._latent_mean.requires_grad_(True) - self._latent_var.requires_grad_(True) @property def list_of_parameters_needing_gradient(self): @@ -959,10 +1046,12 @@ class ZIPLN(PLN): # should change the good initialization, especially for _coef_inflation def smart_init_model_parameters(self): super().smart_init_model_parameters() - self._covariance = init_covariance( - self._counts, self._covariates, self._offsets, self._coef - ) - self._coef_inflation = torch.randn(self.nb_cov, self.dim) + if not hasattr(self, "_covariance"): + self._covariance = init_covariance( + self._counts, self._covariates, self._coef + ) + if not hasattr(self, "_coef_inflation"): + self._coef_inflation = torch.randn(self.nb_cov, self.dim) def random_init_latent_parameters(self): self._dirac = self._counts == 0 -- GitLab From f82926961383e77b9da6423eea75d673584b0a80 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Tue, 2 May 2023 18:49:05 +0200 Subject: [PATCH 13/24] minor changes. 
--- pyPLNmodels/__init__.py | 11 ++++++++++- pyPLNmodels/_closed_forms.py | 2 ++ pyPLNmodels/_utils.py | 6 +++++- pyPLNmodels/models.py | 16 +++++++++++++++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py index 15591263..d895c7cd 100644 --- a/pyPLNmodels/__init__.py +++ b/pyPLNmodels/__init__.py @@ -1,6 +1,12 @@ from .models import PLNPCA, PLN # pylint:disable=[C0114] from .elbos import profiled_elbo_pln, elbo_plnpca, elbo_pln -from ._utils import get_simulated_count_data, get_real_count_data +from ._utils import ( + get_simulated_count_data, + get_real_count_data, + load_model, + load_plnpca, + load_pln, +) __all__ = ( "PLNPCA", @@ -10,4 +16,7 @@ __all__ = ( "elbo_pln", "get_simulated_count_data", "get_real_count_data", + "load_model", + "load_plnpca", + "load_pln", ) diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py index 783d2916..5e00c396 100644 --- a/pyPLNmodels/_closed_forms.py +++ b/pyPLNmodels/_closed_forms.py @@ -11,6 +11,8 @@ def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_sampl def closed_formula_coef(covariates, latent_mean): """Closed form for coef for the M step for the noPCA model.""" + if torch.sum(torch.abs(covariates)) < 1e-15: + return torch.zeros(covariates.shape[1], latent_mean.shape[1]) return torch.mm( torch.mm(torch.inverse(torch.mm(covariates.T, covariates)), covariates.T), latent_mean, diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 4d97045b..3a1a9a7a 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -347,7 +347,7 @@ def format_data(data): def format_model_param(counts, covariates, offsets, offsets_formula): counts = format_data(counts) if covariates is None: - covariates = torch.zeros(counts.shape[0]) + covariates = torch.zeros(counts.shape[0], 1) else: covariates = format_data(covariates) if offsets is None: @@ -585,6 +585,10 @@ def load_model(path_of_directory): return data +def load_pln(path_of_directory): + return load_model(path_of_directory) + + def load_plnpca(path_of_directory, ranks=None): os.chdir(path_of_directory) if ranks is None: diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index e09c25df..d01424ff 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -517,6 +517,9 @@ class PLN(_PLN): NAME = "PLN" coef: torch.Tensor + def get_class(self): + return PLN + @property def description(self): return "full covariance model." 
@@ -621,6 +624,9 @@ class PLN(_PLN): class PLNPCA: NAME = "PLNPCA" + def get_class(self): + return PLNPCA + @singledispatchmethod def __init__( self, @@ -872,6 +878,9 @@ class _PLNPCA(_PLN): NAME = "PLNPCA" _components: torch.Tensor + def get_class(self): + return _PLNPCA + @singledispatchmethod def __init__(self, counts, covariates, offsets, rank, dict_initialization=None): self._rank = rank @@ -884,6 +893,11 @@ class _PLNPCA(_PLN): self._fitted = False self.plotargs = PLNPlotArgs(self.WINDOW) + @__init__.register(str) + def _(self, formula, data, dict_initialization): + counts, covariates, offsets = extract_data_from_formula(formula, data) + self.__init__(counts, covariates, offsets, None, dict_initialization) + def set_init_parameters(self, dict_parameters): for key, array in dict_parameters.items(): array = format_data(array) @@ -901,7 +915,7 @@ class _PLNPCA(_PLN): @property def model_path(self): - return f"{super().model_path}_rank_{self._rank}" + return f"{self.NAME}_rank_{self._rank}" @property def rank(self): -- GitLab From 150874a8485e2c00e7bf3707efcdbefea8cc369e Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Tue, 2 May 2023 18:49:50 +0200 Subject: [PATCH 14/24] Tried to implement all the fixtures in once. Getting trouble with pytest. --- tests/conftest.py | 396 +++++++++++++++++++++++++++++++++++++++++++ tests/test_common.py | 219 +++++------------------- 2 files changed, 437 insertions(+), 178 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1df7d25b..4f2d2970 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,399 @@ import sys +import glob +from functools import singledispatch + +import pytest +from pytest_lazyfixture import lazy_fixture as lf +from pyPLNmodels import load_model, load_plnpca +from tests.import_fixtures import get_dict_fixtures +from pyPLNmodels.models import PLN, _PLNPCA, PLNPCA sys.path.append("../") + + +pln_full_fixture = get_dict_fixtures(PLN) +plnpca_fixture = get_dict_fixtures(_PLNPCA) + + +from tests.import_data import ( + data_sim_0cov, + data_sim_2cov, + data_real, +) + + +counts_sim_0cov = data_sim_0cov["counts"] +covariates_sim_0cov = data_sim_0cov["covariates"] +offsets_sim_0cov = data_sim_0cov["offsets"] + +counts_sim_2cov = data_sim_2cov["counts"] +covariates_sim_2cov = data_sim_2cov["covariates"] +offsets_sim_2cov = data_sim_2cov["offsets"] + +counts_real = data_real["counts"] + + +def add_fixture_to_dict(my_dict, string_fixture): + my_dict[string_fixture] = lf(string_fixture) + return my_dict + + +def add_list_of_fixture_to_dict( + my_dict, name_of_list_of_fixtures, list_of_string_fixtures +): + my_dict[name_of_list_of_fixtures] = [] + for string_fixture in list_of_string_fixtures: + my_dict[name_of_list_of_fixtures].append(lf(string_fixture)) + return my_dict + + +RANK = 8 +RANKS = [2, 6] + +# dict_fixtures_models = [] + + +@singledispatch +def convenient_plnpca( + counts, + covariates=None, + offsets=None, + offsets_formula=None, + dict_initialization=None, +): + return _PLNPCA( + counts, covariates, offsets, rank=RANK, dict_initialization=dict_initialization + ) + + +@convenient_plnpca.register(str) +def _(formula, data, offsets_formula, dict_initialization=None): + return _PLNPCA(formula, data, rank=RANK, dict_initialization=dict_initialization) + + +@singledispatch +def convenientplnpca( + counts, + covariates=None, + offsets=None, + offsets_formula=None, + dict_initialization=None, +): + return PLNPCA( + counts, + covariates, + offsets, + offsets_formula, + 
dict_of_dict_initialization=dict_initialization, + ranks=RANKS, + ) + + +@convenientplnpca.register(str) +def _(formula, data, offsets_formula, dict_initialization=None): + return PLNPCA( + formula, + data, + offsets_formula, + ranks=RANKS, + dict_of_dict_initialization=dict_initialization, + ) + + +params = [PLN, convenient_plnpca, convenientplnpca] +dict_fixtures = {} + + +@pytest.fixture(params=params) +def simulated_pln_0cov_array(request): + cls = request.param + pln_full = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) + return pln_full + + +@pytest.fixture +def simulated_fitted_pln_0cov_array(simulated_pln_0cov_array): + simulated_pln_0cov_array.fit() + return simulated_pln_0cov_array + + +@pytest.fixture(params=params) +def simulated_pln_0cov_formula(request): + cls = request.param + pln_full = cls("counts ~ 0", data_sim_0cov) + return pln_full + + +@pytest.fixture +def simulated_fitted_pln_0cov_formula(simulated_pln_0cov_formula): + simulated_pln_0cov_formula.fit() + return simulated_pln_0cov_formula + + +@pytest.fixture +def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula): + simulated_fitted_pln_0cov_formula.save() + path = simulated_fitted_pln_0cov_formula.model_path + name = simulated_fitted_pln_0cov_formula.NAME + if name == "PLN" or name == "_PLNPCA": + init = load_model(path) + if name == "PLNPCA": + init = load_plnpca(path) + new = simulated_loaded_pln_0cov_formula.get_class( + "counts ~0", data_sim_0cov, dict_initialization=init + ) + return new + + +@pytest.fixture +def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): + simulated_fitted_pln_0cov_array.save() + path = simulated_fitted_pln_0cov_array.model_path + name = simulated_fitted_pln_0cov_array.NAME + if name == "PLN" or name == "_PLNPCA": + init = load_model(path) + if name == "PLNPCA": + init = load_plnpca(path) + new = simulated_fitted_pln_0cov_array.get_class( + counts_sim_0cov, + covariates_sim_0cov, + offsets_sim_0cov, + dict_initialization=init, + ) + return new + + +sim_pln_0cov_instance = [ + "simulated_pln_0cov_array", + "simulated_pln_0cov_formula", +] +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance +) + +sim_pln_0cov_fitted = [ + "simulated_fitted_pln_0cov_array", + "simulated_fitted_pln_0cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted +) + +sim_pln_0cov_loaded = [ + "simulated_loaded_pln_0cov_array", + "simulated_loaded_pln_0cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded +) + +sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_0cov", sim_pln_0cov) + + +@pytest.fixture(params=params) +def simulated_pln_2cov_array(request): + cls = request.param + pln_full = cls(counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov) + return pln_full + + +@pytest.fixture +def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array): + simulated_pln_2cov_array.fit() + return simulated_pln_2cov_array + + +@pytest.fixture(params=params) +def simulated_pln_2cov_formula(): + pln_full = cls("counts ~ 0 + covariates", data_sim_2cov) + return pln_full + + +@pytest.fixture +def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula): + simulated_pln_2cov_formula.fit() + return simulated_pln_2cov_formula + + +@pytest.fixture +def 
simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula): + simulated_fitted_pln_2cov_formula.save() + path = simulated_fitted_pln_2cov_formula.model_path + name = simulated_fitted_pln_2cov_formula.NAME + if name == "PLN": + init = load_model(path) + if name == "PLNPCA": + init = load_plnpca(path) + new = simulated_fitted_pln_2cov_formula.get_class( + "counts ~1", data_sim_2cov, dict_initialization=init + ) + return new + + +@pytest.fixture +def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): + simulated_fitted_pln_2cov_array.save() + path = simulated_fitted_pln_2cov_array.model_path + name = simulated_fitted_pln_2cov_array.NAME + if name == "PLN" or name == "_PLNPCA": + init = load_model(path) + if name == "PLNPCA": + init = load_model(path) + new = simulated_fitted_pln_2cov_array.get_class( + counts_sim_2cov, + covariates_sim_2cov, + offsets_sim_2cov, + dict_initialization=init, + ) + return new + + +sim_pln_2cov_instance = [ + "simulated_pln_2cov_array", + "simulated_pln_2cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance +) + +sim_pln_2cov_fitted = [ + "simulated_fitted_pln_2cov_array", + "simulated_fitted_pln_2cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted +) + +sim_pln_2cov_loaded = [ + "simulated_loaded_pln_2cov_array", + "simulated_loaded_pln_2cov_formula", +] + +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded +) + +sim_pln_2cov = sim_pln_2cov_instance + sim_pln_2cov_fitted + sim_pln_2cov_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_pln_2cov) + + +@pytest.fixture(params=params) +def real_pln_intercept_array(request): + cls = request.param + pln_full = cls(counts_real) + return pln_full + + +@pytest.fixture +def real_fitted_pln_intercept_array(real_pln_intercept_array): + real_pln_intercept_array.fit() + return real_pln_intercept_array + + +@pytest.fixture(params=params) +def real_pln_intercept_formula(request): + cls = request.param + pln_full = cls("counts ~ 1", data_real) + return pln_full + + +@pytest.fixture +def real_fitted_pln_intercept_formula(real_pln_intercept_formula): + real_pln_intercept_formula.fit() + return real_pln_intercept_formula + + +@pytest.fixture +def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): + real_fitted_pln_intercept_formula.save() + path = real_fitted_pln_intercept_formula.model_path + name = real_fitted_pln_intercept_formula.NAME + if name == "PLN" or name == "_PLNPCA": + init = load_model(path) + if name == "PLNPCA": + init = load_plnpca(path) + new = real_fitted_pln_intercept_formula.get_class( + "counts~ 1", data_real, dict_initialization=init + ) + return new + + +@pytest.fixture +def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): + real_fitted_pln_intercept_array.save() + path = real_fitted_pln_intercept_array.model_path + name = real_fitted_pln_intercept_array.NAME + if name == "PLN" or name == "_PLNPCA": + init = load_model(path) + if name == "PLNPCA": + init = load_plnpca(path) + new = real_fitted_pln_intercept_array.get_class( + counts_real, dict_initialization=init + ) + return new + + +real_pln_instance = [ + "real_pln_intercept_array", + "real_pln_intercept_formula", +] +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_instance", real_pln_instance +) + +real_pln_fitted = [ + 
"real_fitted_pln_intercept_array", + "real_fitted_pln_intercept_formula", +] +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_fitted", real_pln_fitted +) + +real_pln_loaded = [ + "real_loaded_pln_intercept_array", + "real_loaded_pln_intercept_formula", +] +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_loaded", real_pln_loaded +) + +sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded + +loaded_pln = real_pln_loaded + sim_loaded_pln +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln) + +simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted +) + +fitted_pln = real_pln_fitted + simulated_pln_fitted +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln) + +loaded_and_fitted_pln = fitted_pln + loaded_pln +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln +) + +real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln) + +sim_pln = sim_pln_2cov + sim_pln_0cov +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln) + +all_pln = real_pln + sim_pln +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln) + + +for string_fixture in all_pln: + dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture) + +pytest_plugins = [ + fixture_file.replace("/", ".").replace(".py", "") + for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True) +] diff --git a/tests/test_common.py b/tests/test_common.py index 6fb1cd33..c9bd23ab 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,191 +1,51 @@ +import os + import torch import numpy as np import pandas as pd - -from pyPLNmodels.models import PLN, _PLNPCA -from pyPLNmodels import get_simulated_count_data, get_real_count_data -from tests.utils import MSE - import pytest from pytest_lazyfixture import lazy_fixture as lf -import os - -( - counts_sim, - covariates_sim, - offsets_sim, - true_covariance, - true_coef, -) = get_simulated_count_data(return_true_param=True, nb_cov=2) - - -counts_real = get_real_count_data() -rank = 8 - - -@pytest.fixture -def instance_pln_full(): - pln_full = PLN() - return pln_full - - -@pytest.fixture -def instance__plnpca(): - plnpca = _PLNPCA(rank=rank) - return plnpca +from pyPLNmodels.models import PLN, _PLNPCA +from tests.utils import MSE -@pytest.fixture -def simulated_fitted_pln_full(): - pln_full = PLN() - pln_full.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - return pln_full +# from tests.import_fixtures import get_dict_fixtures +from tests.conftest import dict_fixtures -@pytest.fixture -def simulated_fitted__plnpca(): - plnpca = _PLNPCA(rank=rank) - plnpca.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - return plnpca +# dict_fixtures_pln_full = dict_fixtures_models[0] +# dict_fixtures_plnpca = dict_fixtures_models[1] +# dict_fixturesplnpca = dict_fixtures_models[2] -@pytest.fixture -def loaded_simulated_pln_full(simulated_fitted_pln_full): - simulated_fitted_pln_full.save() - loaded_pln_full = PLN() - loaded_pln_full.load() - return loaded_pln_full +def get_pln_and_plncpca_fixtures(key): + return dict_fixtures_pln_full[key] + dict_fixtures_plnpca[key] -@pytest.fixture -def 
loaded_refit_simulated_pln_full(loaded_simulated_pln_full): - loaded_simulated_pln_full.fit( - counts=counts_sim, - covariates=covariates_sim, - offsets=offsets_sim, - keep_going=True, - ) - return loaded_simulated_pln_full +def get_pca_fixtures(key): + return dict_fixtures_plnpca[key] + dict_fixturesplnpca[key] -@pytest.fixture -def loaded_simulated__plnpca(simulated_fitted__plnpca): - simulated_fitted__plnpca.save() - loaded_pln_full = _PLNPCA(rank=rank) - loaded_pln_full.load() - return loaded_pln_full - -@pytest.fixture -def loaded_refit_simulated__plnpca(loaded_simulated__plnpca): - loaded_simulated__plnpca.fit( - counts=counts_sim, - covariates=covariates_sim, - offsets=offsets_sim, - keep_going=True, - ) - return loaded_simulated__plnpca - - -@pytest.fixture -def real_fitted_pln_full(): - pln_full = PLN() - pln_full.fit(counts=counts_real) - return pln_full - - -@pytest.fixture -def loaded_real_pln_full(real_fitted_pln_full): - real_fitted_pln_full.save() - loaded_pln_full = PLN() - loaded_pln_full.load() - return loaded_pln_full - - -@pytest.fixture -def loaded_refit_real_pln_full(loaded_real_pln_full): - loaded_real_pln_full.fit(counts=counts_real, keep_going=True) - return loaded_real_pln_full - - -@pytest.fixture -def real_fitted__plnpca(): - plnpca = _PLNPCA(rank=rank) - plnpca.fit(counts=counts_real) - return plnpca - - -@pytest.fixture -def loaded_real__plnpca(real_fitted__plnpca): - real_fitted__plnpca.save() - loaded_plnpca = _PLNPCA(rank=rank) - loaded_plnpca.load() - return loaded_plnpca - - -@pytest.fixture -def loaded_refit_real__plnpca(loaded_real__plnpca): - loaded_real__plnpca.fit(counts=counts_real, keep_going=True) - return loaded_real__plnpca - - -real_pln_full = [ - lf("real_fitted_pln_full"), - lf("loaded_real_pln_full"), - lf("loaded_refit_real_pln_full"), -] -real__plnpca = [ - lf("real_fitted__plnpca"), - lf("loaded_real__plnpca"), - lf("loaded_refit_real__plnpca"), -] -simulated_pln_full = [ - lf("simulated_fitted_pln_full"), - lf("loaded_simulated_pln_full"), - lf("loaded_refit_simulated_pln_full"), -] -simulated__plnpca = [ - lf("simulated_fitted__plnpca"), - lf("loaded_simulated__plnpca"), - lf("loaded_refit_simulated__plnpca"), -] - -loaded_sim_pln = [ - lf("loaded_simulated__plnpca"), - lf("loaded_simulated_pln_full"), - lf("loaded_refit_simulated_pln_full"), - lf("loaded_refit_simulated_pln_full"), -] - - -@pytest.mark.parametrize("loaded", loaded_sim_pln) -def test_refit_not_keep_going(loaded): - loaded.fit( - counts=counts_sim, - covariates=covariates_sim, - offsets=offsets_sim, - keep_going=False, +def get_all_fixtures(key): + return ( + dict_fixtures_plnpca[key] + + dict_fixtures_pln_full[key] + + dict_fixturesplnpca[key] ) -all_instances = [lf("instance__plnpca"), lf("instance_pln_full")] - -all_fitted__plnpca = simulated__plnpca + real__plnpca -all_fitted_pln_full = simulated_pln_full + real_pln_full - -simulated_any_pln = simulated__plnpca + simulated_pln_full -real_any_pln = real_pln_full + real__plnpca -all_fitted_models = simulated_any_pln + real_any_pln - - -@pytest.mark.parametrize("any_pln", all_fitted_models) +# @pytest.mark.parametrize("any_pln", [dict_fixtures["simulated_pln_0cov_array"]]) +@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"]) def test_properties(any_pln): - assert hasattr(any_pln, "latent_variables") - assert hasattr(any_pln, "model_parameters") - assert hasattr(any_pln, "latent_parameters") - assert hasattr(any_pln, "optim_parameters") + if any_pln.NAME in ("PLN", "_PLNPCA"): + assert 
hasattr(any_pln, "latent_variables") + assert hasattr(any_pln, "model_parameters") + assert hasattr(any_pln, "latent_parameters") + assert hasattr(any_pln, "optim_parameters") +""" @pytest.mark.parametrize("any_pln", all_fitted_models) def test_show_coef_transform_covariance_pcaprojected(any_pln): any_pln.show() @@ -257,8 +117,8 @@ def test_find_right_coef(sim_pln): assert mse_coef < 0.1 -def test_number_of_iterations_pln_full(simulated_fitted_pln_full): - nb_iterations = len(simulated_fitted_pln_full.elbos_list) +def test_number_of_iterations_pln_full(simulated_fitted_pln_full_0cov): + nb_iterations = len(simulated_fitted_pln_full_0cov.elbos_list) assert 50 < nb_iterations < 300 @@ -273,21 +133,21 @@ def test_computable_elbopca(instance__plnpca, simulated_fitted__plnpca): instance__plnpca.compute_elbo() -def test_computable_elbo_full(instance_pln_full, simulated_fitted_pln_full): - instance_pln_full.counts = simulated_fitted_pln_full.counts - instance_pln_full.covariates = simulated_fitted_pln_full.covariates - instance_pln_full.offsets = simulated_fitted_pln_full.offsets - instance_pln_full.latent_mean = simulated_fitted_pln_full.latent_mean - instance_pln_full.latent_var = simulated_fitted_pln_full.latent_var - instance_pln_full.covariance = simulated_fitted_pln_full.covariance - instance_pln_full.coef = simulated_fitted_pln_full.coef +def test_computable_elbo_full(instance_pln_full, simulated_fitted_pln_full_0cov): + instance_pln_full.counts = simulated_fitted_pln_full_0cov.counts + instance_pln_full.covariates = simulated_fitted_pln_full_0cov.covariates + instance_pln_full.offsets = simulated_fitted_pln_full_0cov.offsets + instance_pln_full.latent_mean = simulated_fitted_pln_full_0cov.latent_mean + instance_pln_full.latent_var = simulated_fitted_pln_full_0cov.latent_var + instance_pln_full.covariance = simulated_fitted_pln_full_0cov.covariance + instance_pln_full.coef = simulated_fitted_pln_full_0cov.coef instance_pln_full.compute_elbo() -def test_fail_count_setter(simulated_fitted_pln_full): +def test_fail_count_setter(simulated_fitted_pln_full_0cov): wrong_counts = torch.randint(size=(10, 5), low=0, high=10) with pytest.raises(Exception): - simulated_fitted_pln_full.counts = wrong_counts + simulated_fitted_pln_full_0cov.counts = wrong_counts @pytest.mark.parametrize("any_pln", all_fitted_models) @@ -334,3 +194,6 @@ def test_wrong_rank(): instance = _PLNPCA(counts_sim.shape[1] + 1) with pytest.warns(UserWarning): instance.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) + + +""" -- GitLab From 83a0d6be28671dc3d70c4912136d92703a5b8e77 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Tue, 2 May 2023 18:51:05 +0200 Subject: [PATCH 15/24] import_fixtures file is wrong. 
--- tests/import_data.py | 39 ++++++ tests/import_fixtures.py | 285 +++++++++++++++++++++++++++++++++++++++ tests/test_pln_full.py | 9 ++ 3 files changed, 333 insertions(+) create mode 100644 tests/import_data.py create mode 100644 tests/import_fixtures.py create mode 100644 tests/test_pln_full.py diff --git a/tests/import_data.py b/tests/import_data.py new file mode 100644 index 00000000..80e613ba --- /dev/null +++ b/tests/import_data.py @@ -0,0 +1,39 @@ +import os + +from pyPLNmodels import ( + get_simulated_count_data, + get_real_count_data, +) + + +( + counts_sim_0cov, + covariates_sim_0cov, + offsets_sim_0cov, + true_covariance_0cov, + true_coef_0cov, +) = get_simulated_count_data(return_true_param=True, nb_cov=0) +( + counts_sim_2cov, + covariates_sim_2cov, + offsets_sim_2cov, + true_covariance_2cov, + true_coef_2cov, +) = get_simulated_count_data(return_true_param=True, nb_cov=2) + +data_sim_0cov = { + "counts": counts_sim_0cov, + "covariates": covariates_sim_0cov, + "offsets": offsets_sim_0cov, +} +true_sim_0cov = {"Sigma": true_covariance_0cov, "beta": true_coef_0cov} +true_sim_2cov = {"Sigma": true_covariance_2cov, "beta": true_coef_2cov} + + +data_sim_2cov = { + "counts": counts_sim_2cov, + "covariates": covariates_sim_2cov, + "offsets": offsets_sim_2cov, +} +counts_real = get_real_count_data() +data_real = {"counts": counts_real} diff --git a/tests/import_fixtures.py b/tests/import_fixtures.py new file mode 100644 index 00000000..e2d9b1ef --- /dev/null +++ b/tests/import_fixtures.py @@ -0,0 +1,285 @@ +import pytest +from pytest_lazyfixture import lazy_fixture as lf +from pyPLNmodels import load_model +from tests.import_data import ( + data_sim_0cov, + data_sim_2cov, + data_real, +) + +counts_sim_0cov = data_sim_0cov["counts"] +covariates_sim_0cov = data_sim_0cov["covariates"] +offsets_sim_0cov = data_sim_0cov["offsets"] + +counts_sim_2cov = data_sim_2cov["counts"] +covariates_sim_2cov = data_sim_2cov["covariates"] +offsets_sim_2cov = data_sim_2cov["offsets"] + +counts_real = data_real["counts"] + + +def add_fixture_to_dict(my_dict, string_fixture): + my_dict[string_fixture] = lf(string_fixture) + return my_dict + + +def add_list_of_fixture_to_dict( + my_dict, name_of_list_of_fixtures, list_of_string_fixtures +): + my_dict[name_of_list_of_fixtures] = [] + for string_fixture in list_of_string_fixtures: + my_dict[name_of_list_of_fixtures].append(lf(string_fixture)) + return my_dict + + +def get_dict_fixtures(PLNor_PLNPCA): + dict_fixtures = {} + + @pytest.fixture + def simulated_pln_0cov_array(): + pln_full = PLNor_PLNPCA(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) + return pln_full + + @pytest.fixture + def simulated_fitted_pln_0cov_array(simulated_pln_0cov_array): + simulated_pln_0cov_array.fit() + return simulated_pln_0cov_array + + @pytest.fixture + def simulated_pln_0cov_formula(): + pln_full = PLNor_PLNPCA("counts ~ 0", data_sim_0cov) + return pln_full + + @pytest.fixture + def simulated_fitted_pln_0cov_formula(simulated_pln_0cov_formula): + simulated_pln_0cov_array.fit() + return simulated_pln_0cov_formula + + @pytest.fixture + def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula): + simulated_fitted_pln_0cov_formula.save() + if simulated_fitted_pln_0cov_formula.NAME == "PLN": + init = load_model("PLN_nbcov_0") + if simulated_fitted_pln_0cov_formula.NAME == "_PLNPCA": + init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_0cov_formula.rank}") + new = PLNor_PLNPCA("counts ~0", data_sim_0cov, dict_initialization=init) + return new + + 
@pytest.fixture + def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): + simulated_fitted_pln_0cov_array.save() + if simulated_fitted_pln_0cov_array.NAME == "PLN": + init = load_model("PLN_nbcov_0") + if simulated_fitted_pln_0cov_array.NAME == "_PLNPCA": + init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_0cov_array.rank}") + new = PLNor_PLNPCA( + counts_sim_0cov, + covariates_sim_0cov, + offsets_sim_0cov, + dict_initialization=init, + ) + return new + + sim_pln_0cov_instance = [ + "simulated_pln_0cov_array", + "simulated_pln_0cov_formula", + ] + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance + ) + + sim_pln_0cov_fitted = [ + "simulated_fitted_pln_0cov_array", + "simulated_fitted_pln_0cov_formula", + ] + + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted + ) + + sim_pln_0cov_loaded = [ + "simulated_loaded_pln_0cov_array", + "simulated_loaded_pln_0cov_formula", + ] + + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded + ) + + sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_0cov", sim_pln_0cov + ) + + @pytest.fixture + def simulated_pln_2cov_array(): + pln_full = PLNor_PLNPCA(counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov) + return pln_full + + @pytest.fixture + def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array): + simulated_pln_2cov_array.fit() + return simulated_pln_2cov_array + + @pytest.fixture + def simulated_pln_2cov_formula(): + pln_full = PLNor_PLNPCA("counts ~ 0 + covariates", data_sim_2cov) + return pln_full + + @pytest.fixture + def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula): + simulated_pln_2cov_formula.fit() + return simulated_pln_2cov_formula + + @pytest.fixture + def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula): + simulated_fitted_pln_2cov_formula.save() + if simulated_fitted_pln_2cov_formula.NAME == "PLN": + init = load_model("PLN_nbcov_2") + if simulated_fitted_pln_2cov_formula.NAME == "_PLNPCA": + init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_2cov_formula.rank}") + new = PLNor_PLNPCA("counts ~2", data_sim_2cov, dict_initialization=init) + return new + + @pytest.fixture + def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): + simulated_fitted_pln_2cov_array.save() + if simulated_fitted_pln_2cov_array.NAME == "PLN": + init = load_model("PLN_nbcov_2") + if simulated_fitted_pln_2cov_array.NAME == "_PLNPCA": + init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_2cov_array.rank}") + new = PLNor_PLNPCA( + counts_sim_2cov, + covariates_sim_2cov, + offsets_sim_2cov, + dict_initialization=init, + ) + return new + + sim_pln_2cov_instance = [ + "simulated_pln_2cov_array", + "simulated_pln_2cov_formula", + ] + + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance + ) + + sim_pln_2cov_fitted = [ + "simulated_fitted_pln_2cov_array", + "simulated_fitted_pln_2cov_formula", + ] + + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted + ) + + sim_pln_2cov_loaded = [ + "simulated_loaded_pln_2cov_array", + "simulated_loaded_pln_2cov_formula", + ] + + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded + ) + + sim_pln_2cov = sim_pln_2cov_instance + 
sim_pln_2cov_fitted + sim_pln_2cov_loaded + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "sim_pln_2cov", sim_pln_2cov + ) + + @pytest.fixture + def real_pln_intercept_array(): + pln_full = PLNor_PLNPCA(counts_real) + return pln_full + + @pytest.fixture + def real_fitted_pln_intercept_array(real_pln_intercept_array): + real_pln_intercept_array.fit() + return real_pln_intercept_array + + @pytest.fixture + def real_pln_intercept_formula(): + pln_full = PLNor_PLNPCA("counts ~ 1", data_real) + return pln_full + + @pytest.fixture + def real_fitted_pln_intercept_formula(real_pln_intercept_formula): + real_pln_intercept_formula.fit() + return real_pln_intercept_formula + + @pytest.fixture + def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): + real_fitted_pln_intercept_formula.save() + if real_fitted_pln_intercept_formula.NAME == "PLN": + init = load_model("PLN_nbcov_2") + if real_fitted_pln_intercept_formula.NAME == "_PLNPCA": + init = load_model(f"PLNPCA_rank_{real_fitted_pln_intercept_formula.rank}") + new = PLNor_PLNPCA("counts ~2", data_real, dict_initialization=init) + return new + + @pytest.fixture + def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): + real_fitted_pln_intercept_array.save() + if real_fitted_pln_intercept_array.NAME == "PLN": + init = load_model("PLN_nbcov_2") + if real_fitted_pln_intercept_array.NAME == "_PLNPCA": + init = load_model(f"PLNPCA_rank_{real_fitted_pln_intercept_array.rank}") + new = PLNor_PLNPCA(counts_real, dict_initialization=init) + return new + + real_pln_instance = [ + "real_pln_intercept_array", + "real_pln_intercept_formula", + ] + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_instance", real_pln_instance + ) + + real_pln_fitted = [ + "real_fitted_pln_intercept_array", + "real_fitted_pln_intercept_formula", + ] + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_fitted", real_pln_fitted + ) + + real_pln_loaded = [ + "real_loaded_pln_intercept_array", + "real_loaded_pln_intercept_formula", + ] + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "real_pln_loaded", real_pln_loaded + ) + + sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded + + loaded_pln = real_pln_loaded + sim_loaded_pln + dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln) + + simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted + ) + + fitted_pln = real_pln_fitted + simulated_pln_fitted + dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln) + + loaded_and_fitted_pln = fitted_pln + loaded_pln + dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln + ) + + real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded + dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln) + + sim_pln = sim_pln_2cov + sim_pln_0cov + dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln) + + all_pln = real_pln + sim_pln + dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln) + + for string_fixture in all_pln: + dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture) + + return dict_fixtures diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py new file mode 100644 index 00000000..078c5bb5 --- /dev/null +++ b/tests/test_pln_full.py @@ -0,0 +1,9 @@ +import 
torch + +from import_fixtures_and_data import get_dict_fixtures +from pyPLNmodels import PLN + + +df = get_dict_fixtures(PLN) +for key, fixture in df.items(): + print(len(fixture)) -- GitLab From 62d951907eba4144d98ea7094e65e98a28d8a0cb Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Thu, 4 May 2023 18:31:14 +0200 Subject: [PATCH 16/24] changed the tests thanks to JB. Working fine. Now I need to pass them all. Need to change the elbos and the poiss_reg to handle the case where covariates is None (no covariates is given). --- pyPLNmodels/elbos.py | 6 +++++- tests/conftest.py | 27 ++++++++++++++++++++++++--- tests/test_common.py | 37 ++++++++++++++++++++++++++++++------- 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py index 49f4dc02..8fb8b319 100644 --- a/pyPLNmodels/elbos.py +++ b/pyPLNmodels/elbos.py @@ -22,7 +22,11 @@ def elbo_pln(counts, covariates, offsets, latent_mean, latent_var, covariance, c n_samples, dim = counts.shape s_rond_s = torch.multiply(latent_var, latent_var) offsets_plus_m = offsets + latent_mean - m_minus_xb = latent_mean - torch.mm(covariates, coef) + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + m_minus_xb = latent_mean - XB d_plus_minus_xb2 = torch.diag(torch.sum(s_rond_s, dim=0)) + torch.mm( m_minus_xb.T, m_minus_xb ) diff --git a/tests/conftest.py b/tests/conftest.py index 4f2d2970..a6d47ec0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import sys import glob -from functools import singledispatch +from functools import singledispatch, singledispatchmethod import pytest from pytest_lazyfixture import lazy_fixture as lf @@ -67,7 +67,7 @@ def convenient_plnpca( @convenient_plnpca.register(str) -def _(formula, data, offsets_formula, dict_initialization=None): +def _(formula, data, offsets_formula=None, dict_initialization=None): return _PLNPCA(formula, data, rank=RANK, dict_initialization=dict_initialization) @@ -89,8 +89,16 @@ def convenientplnpca( ) +# class convenientplnpca(PLNPCA): +# @singledispatchmethod +# def __init__(self,counts, covariates=None, offsets= None, offsets_formula=None, dict_initialization=None): +# super().__init__(counts, covariates, offsets, offsets_formula, dict_initialization) + +# def _(formula, data, offsets.) 
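# Standalone illustration (not part of this patch) of the covariates handling
# introduced in the elbos.py hunk above: when no covariates are given, the
# X @ B term is dropped by returning a scalar 0 that broadcasts away. The
# helper name linear_predictor is made up here for the sketch.
import torch


def linear_predictor(covariates, coef):
    if covariates is None:
        return 0  # offsets + latent_mean - XB is left unchanged
    return covariates @ coef


offsets, latent_mean = torch.zeros(4, 3), torch.randn(4, 3)
assert torch.allclose(
    offsets + latent_mean - linear_predictor(None, None),
    offsets + latent_mean,
)
assert (offsets + latent_mean
        - linear_predictor(torch.randn(4, 2), torch.randn(2, 3))).shape == (4, 3)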
+ + @convenientplnpca.register(str) -def _(formula, data, offsets_formula, dict_initialization=None): +def _(formula, data, offsets_formula=None, dict_initialization=None): return PLNPCA( formula, data, @@ -100,11 +108,23 @@ def _(formula, data, offsets_formula, dict_initialization=None): ) +def cache(func): + dict_cache = {} + + def new_func(request): + if request.param.__name__ not in dict_cache: + dict_cache[request.param.__name__] = func(request) + return dict_cache[request.param.__name__] + + return new_func + + params = [PLN, convenient_plnpca, convenientplnpca] dict_fixtures = {} @pytest.fixture(params=params) +@cache def simulated_pln_0cov_array(request): cls = request.param pln_full = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) @@ -298,6 +318,7 @@ def real_fitted_pln_intercept_array(real_pln_intercept_array): @pytest.fixture(params=params) def real_pln_intercept_formula(request): cls = request.param + print("cls:", cls) pln_full = cls("counts ~ 1", data_real) return pln_full diff --git a/tests/test_common.py b/tests/test_common.py index c9bd23ab..97f81357 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,4 +1,5 @@ import os +import functools import torch import numpy as np @@ -35,15 +36,37 @@ def get_all_fixtures(key): ) +def filter_models(models_name): + def decorator(my_test): + @functools.wraps(my_test) + def new_test(**kwargs): + fixture = next(iter(kwargs.values())) + if type(fixture).__name__ not in models_name: + return None + return my_test(**kwargs) + + return new_test + + return decorator + + +# @pytest.mark.parametrize("any_pln", [dict_fixtures["simulated_pln_0cov_array"]]) # @pytest.mark.parametrize("any_pln", [dict_fixtures["simulated_pln_0cov_array"]]) -@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"]) -def test_properties(any_pln): - if any_pln.NAME in ("PLN", "_PLNPCA"): - assert hasattr(any_pln, "latent_variables") - assert hasattr(any_pln, "model_parameters") - assert hasattr(any_pln, "latent_parameters") - assert hasattr(any_pln, "optim_parameters") +@filter_models(["PLN", "_PLNPCA"]) +def test_properties(simulated_fitted_pln_0cov_array): + assert hasattr(simulated_fitted_pln_0cov_array, "model_parameters") + assert hasattr(simulated_fitted_pln_0cov_array, "latent_parameters") + print("model_param", simulated_fitted_pln_0cov_array.model_parameters) + assert hasattr(simulated_fitted_pln_0cov_array, "latent_variables") + assert hasattr(simulated_fitted_pln_0cov_array, "optim_parameters") + + +@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) +def test_print(any_pln): + print(any_pln) + +print("len :", len(dict_fixtures["all_pln"])) """ @pytest.mark.parametrize("any_pln", all_fitted_models) -- GitLab From 3321da94296fd8433fc0cc73c6f869ee08d4f533 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Tue, 9 May 2023 12:59:08 +0200 Subject: [PATCH 17/24] continu to implement the tests --- pyPLNmodels/_closed_forms.py | 10 +- pyPLNmodels/_utils.py | 41 +++-- pyPLNmodels/elbos.py | 8 +- pyPLNmodels/models.py | 90 +++++++---- tests/conftest.py | 163 +++++++++----------- tests/import_fixtures.py | 285 ----------------------------------- tests/test_args.py | 15 ++ tests/test_common.py | 108 ++++--------- 8 files changed, 211 insertions(+), 509 deletions(-) diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py index 5e00c396..087bc386 100644 --- a/pyPLNmodels/_closed_forms.py +++ b/pyPLNmodels/_closed_forms.py @@ -3,7 +3,11 @@ import torch # 
pylint:disable=[C0114] def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_samples): """Closed form for covariance for the M step for the noPCA model.""" - m_moins_xb = latent_mean - torch.mm(covariates, coef) + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + m_moins_xb = latent_mean - XB closed = torch.mm(m_moins_xb.T, m_moins_xb) closed += torch.diag(torch.sum(torch.multiply(latent_var, latent_var), dim=0)) return 1 / (n_samples) * closed @@ -11,8 +15,8 @@ def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_sampl def closed_formula_coef(covariates, latent_mean): """Closed form for coef for the M step for the noPCA model.""" - if torch.sum(torch.abs(covariates)) < 1e-15: - return torch.zeros(covariates.shape[1], latent_mean.shape[1]) + if covariates is None: + return 0 return torch.mm( torch.mm(torch.inverse(torch.mm(covariates.T, covariates)), covariates.T), latent_mean, diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 3a1a9a7a..23a02fd2 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -238,6 +238,8 @@ def components_from_covariance(covariance, rank): def init_coef(counts, covariates, offsets): + if covariates is None: + return None poiss_reg = PoissonReg() poiss_reg.fit(counts, covariates, offsets) return poiss_reg.beta @@ -274,8 +276,11 @@ def log_posterior(counts, covariates, offsets, posterior_mean, components, coef) components_posterior_mean = torch.matmul( components.unsqueeze(0), posterior_mean.unsqueeze(2) ).squeeze() - - log_lambda = offsets + components_posterior_mean + covariates @ coef + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + log_lambda = offsets + components_posterior_mean + XB first_term = ( -rank / 2 * math.log(2 * math.pi) - 1 / 2 * torch.norm(posterior_mean, dim=-1) ** 2 @@ -333,6 +338,8 @@ def init_S(counts, covariates, offsets, beta, C, M): def format_data(data): + if data is None: + return None if isinstance(data, pd.DataFrame): return torch.from_numpy(data.values).double().to(DEVICE) if isinstance(data, np.ndarray): @@ -346,9 +353,7 @@ def format_data(data): def format_model_param(counts, covariates, offsets, offsets_formula): counts = format_data(counts) - if covariates is None: - covariates = torch.zeros(counts.shape[0], 1) - else: + if covariates is not None: covariates = format_data(covariates) if offsets is None: if offsets_formula == "logsum": @@ -371,6 +376,7 @@ def remove_useless_intercepts(covariates): second_column = covariates[:, 1] diff = first_column - second_column if torch.sum(torch.abs(diff - diff[0])) == 0: + print("removing one") return covariates[:, 1:] return covariates @@ -378,9 +384,10 @@ def remove_useless_intercepts(covariates): def check_data_shape(counts, covariates, offsets): n_counts, p_counts = counts.shape n_offsets, p_offsets = offsets.shape - n_cov, _ = covariates.shape check_two_dimensions_are_equal("counts", "offsets", n_counts, n_offsets, 0) - check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0) + if covariates is not None: + n_cov, _ = covariates.shape + check_two_dimensions_are_equal("counts", "covariates", n_counts, n_cov, 0) check_two_dimensions_are_equal("counts", "offsets", p_counts, p_offsets, 1) @@ -567,6 +574,7 @@ def check_dimensions_are_equal(tens1, tens2): def load_model(path_of_directory): + working_dict = os.getcwd() os.chdir(path_of_directory) all_files = os.listdir() data = {} @@ -574,14 +582,13 @@ def load_model(path_of_directory): if len(filename) > 4: if filename[-4:] 
== ".csv": parameter = filename[:-4] - # data[parameter] = pd.read_csv(filename, header=None).values try: data[parameter] = pd.read_csv(filename, header=None).values except pd.errors.EmptyDataError as err: print( - f"Can t load {parameter} since empty. Standard initialization will be performed" + f"Can't load {parameter} since empty. Standard initialization will be performed" ) - os.chdir("../") + os.chdir(working_dict) return data @@ -590,6 +597,7 @@ def load_pln(path_of_directory): def load_plnpca(path_of_directory, ranks=None): + working_dict = os.getcwd() os.chdir(path_of_directory) if ranks is None: dirnames = os.listdir() @@ -598,14 +606,14 @@ def load_plnpca(path_of_directory, ranks=None): try: rank = int(dirname[-1]) except ValueError: - print( - f"Can t load the model {dirname}. End of {dirname} should be an int" + raise ValueError( + f"Can't load the model {dirname}. End of {dirname} should be an int" ) ranks.append(rank) datas = {} for rank in ranks: - datas[rank] = load_model(f"PLNPCA_rank_{rank}") - os.chdir("../") + datas[rank] = load_model(f"_PLNPCA_rank_{rank}") + os.chdir(working_dict) return datas @@ -622,8 +630,11 @@ def extract_data_from_formula(formula, data): dmatrix = dmatrices(formula, data=data) counts = dmatrix[0] covariates = dmatrix[1] - if len(covariates) > 0: + if covariates.size > 0: + pass covariates = remove_useless_intercepts(covariates) + else: + covariates = None offsets = data.get("offsets", None) return counts, covariates, offsets diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py index 8fb8b319..528cba46 100644 --- a/pyPLNmodels/elbos.py +++ b/pyPLNmodels/elbos.py @@ -94,9 +94,11 @@ def elbo_plnpca(counts, covariates, offsets, latent_mean, latent_var, components """ n_samples = counts.shape[0] rank = components.shape[1] - log_intensity = ( - offsets + torch.mm(covariates, coef) + torch.mm(latent_mean, components.T) - ) + if covariates is None: + XB = 0 + else: + XB = covariates @ coef + log_intensity = offsets + XB + torch.mm(latent_mean, components.T) s_rond_s = torch.multiply(latent_var, latent_var) counts_log_intensity = torch.sum(torch.multiply(counts, log_intensity)) minus_intensity_plus_s_rond_s_cct = torch.sum( diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index d01424ff..cdf11513 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -38,7 +38,6 @@ from ._utils import ( check_dimensions_are_equal, check_right_rank, remove_useless_intercepts, - is_dict_of_dict, extract_data_from_formula, get_dict_initialization, ) @@ -95,9 +94,7 @@ class _PLN(ABC): self._fitted = False self.plotargs = PLNPlotArgs(self.WINDOW) if dict_initialization is not None: - for key, value in dict_initialization.items(): - value = torch.from_numpy(dict_initialization[key]) - setattr(self, key, value) + self.set_init_parameters(dict_initialization) @__init__.register(str) def _( @@ -115,6 +112,14 @@ class _PLN(ABC): offsets = data.get("offsets", None) self.__init__(counts, covariates, offsets, offsets_formula, dict_initialization) + def set_init_parameters(self, dict_initialization): + if "coef" not in dict_initialization.keys(): + print("No coef is initialized.") + self.coef = None + for key, array in dict_initialization.items(): + array = format_data(array) + setattr(self, key, array) + @property def fitted(self): return @@ -133,12 +138,16 @@ class _PLN(ABC): @property def nb_cov(self): + if self.covariates is None: + return 0 return self.covariates.shape[1] def smart_init_coef(self): self._coef = init_coef(self._counts, self._covariates, 
self._offsets) def random_init_coef(self): + if self.nb_cov == 0: + self._coef = None self._coef = torch.randn((self.nb_cov, self.dim), device=DEVICE) @abstractmethod @@ -182,7 +191,7 @@ class _PLN(ABC): nb_max_iteration=50000, lr=0.01, class_optimizer=torch.optim.Rprop, - tol=1e-5, + tol=1e-3, do_smart_init=True, verbose=False, ): @@ -427,7 +436,7 @@ class _PLN(ABC): return None def save(self, path_of_directory="./"): - path = f"{path_of_directory}/{self.model_path}/" + path = f"{path_of_directory}/{self.path_to_directory}{self.directory_name}" os.makedirs(path, exist_ok=True) for key, value in self.dict_parameters.items(): filename = f"{path}/{key}.csv" @@ -499,8 +508,13 @@ class _PLN(ABC): return self.covariance def predict(self, covariates=None): + if covariates is not None and self.nb_cov == 0: + raise AttributeError("No covariates in the model, can't predict") if covariates is None: - return self.coef[0, :] + if self.covariates is None: + print("No covariates in the model.") + return None + return self.covariates @ self.coef if covariates.shape[-1] != self.nb_cov: error_string = f"X has wrong shape ({covariates.shape}).Should" error_string += f" be ({self.n_samples, self.nb_cov})." @@ -508,18 +522,19 @@ class _PLN(ABC): return covariates @ self.coef @property - def model_path(self): + def directory_name(self): return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" + @property + def path_to_directory(self): + return "" + # need to do a good init for M and S class PLN(_PLN): NAME = "PLN" coef: torch.Tensor - def get_class(self): - return PLN - @property def description(self): return "full covariance model." @@ -624,9 +639,6 @@ class PLN(_PLN): class PLNPCA: NAME = "PLNPCA" - def get_class(self): - return PLNPCA - @singledispatchmethod def __init__( self, @@ -634,7 +646,7 @@ class PLNPCA: covariates=None, offsets=None, offsets_formula="logsum", - ranks=range(1, 5), + ranks=range(3, 5), dict_of_dict_initialization=None, ): self.init_data(counts, covariates, offsets, offsets_formula) @@ -653,7 +665,7 @@ class PLNPCA: formula: str, data: dict, offsets_formula="logsum", - ranks=range(1, 5), + ranks=range(3, 5), dict_of_dict_initialization=None, ): counts, covariates, offsets = extract_data_from_formula(formula, data) @@ -666,6 +678,18 @@ class PLNPCA: dict_of_dict_initialization, ) + @property + def covariates(self): + return self.models[0].covariates + + @property + def counts(self): + return self.models[0].counts + + @property + def offsets(self): + return self.models[0].offsets + def init_models(self, ranks, dict_of_dict_initialization): if isinstance(ranks, (Iterable, np.ndarray)): self.models = [] @@ -825,10 +849,10 @@ class PLNPCA: ranks = self.ranks for model in self.models: if model.rank in ranks: - model.save(f"{path_of_directory}/{self.model_path}") + model.save(path_of_directory) @property - def model_path(self): + def directory_name(self): return f"{self.NAME}_nbcov_{self.nb_cov}_dim_{self.dim}" @property @@ -875,18 +899,16 @@ class PLNPCA: # Here, setting the value for each key in dict_parameters class _PLNPCA(_PLN): - NAME = "PLNPCA" + NAME = "_PLNPCA" _components: torch.Tensor - def get_class(self): - return _PLNPCA - @singledispatchmethod def __init__(self, counts, covariates, offsets, rank, dict_initialization=None): self._rank = rank - self._counts = counts - self._covariates = covariates - self._offsets = offsets + self._counts, self._covariates, self._offsets = format_model_param( + counts, covariates, offsets, None + ) + check_data_shape(self._counts, 
self._covariates, self._offsets) self.check_if_rank_is_too_high() if dict_initialization is not None: self.set_init_parameters(dict_initialization) @@ -894,14 +916,9 @@ class _PLNPCA(_PLN): self.plotargs = PLNPlotArgs(self.WINDOW) @__init__.register(str) - def _(self, formula, data, dict_initialization): + def _(self, formula, data, rank, dict_initialization): counts, covariates, offsets = extract_data_from_formula(formula, data) - self.__init__(counts, covariates, offsets, None, dict_initialization) - - def set_init_parameters(self, dict_parameters): - for key, array in dict_parameters.items(): - array = format_data(array) - setattr(self, key, array) + self.__init__(counts, covariates, offsets, rank, dict_initialization) def check_if_rank_is_too_high(self): if self.dim < self.rank: @@ -914,8 +931,13 @@ class _PLNPCA(_PLN): self._rank = self.dim @property - def model_path(self): + def directory_name(self): return f"{self.NAME}_rank_{self._rank}" + # return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/{self.NAME}_rank_{self._rank}" + + @property + def path_to_directory(self): + return f"PLNPCA_nbcov_{self.nb_cov}_dim_{self.dim}/" @property def rank(self): @@ -968,6 +990,8 @@ class _PLNPCA(_PLN): @property def list_of_parameters_needing_gradient(self): + if self._coef is None: + return [self._components, self._latent_mean, self._latent_var] return [self._components, self._coef, self._latent_mean, self._latent_var] def compute_elbo(self): diff --git a/tests/conftest.py b/tests/conftest.py index a6d47ec0..26daaf2e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,19 @@ import sys import glob -from functools import singledispatch, singledispatchmethod +from functools import singledispatch import pytest from pytest_lazyfixture import lazy_fixture as lf from pyPLNmodels import load_model, load_plnpca -from tests.import_fixtures import get_dict_fixtures from pyPLNmodels.models import PLN, _PLNPCA, PLNPCA -sys.path.append("../") +sys.path.append("../") -pln_full_fixture = get_dict_fixtures(PLN) -plnpca_fixture = get_dict_fixtures(_PLNPCA) +pytest_plugins = [ + fixture_file.replace("/", ".").replace(".py", "") + for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True) +] from tests.import_data import ( @@ -34,7 +35,7 @@ counts_real = data_real["counts"] def add_fixture_to_dict(my_dict, string_fixture): - my_dict[string_fixture] = lf(string_fixture) + my_dict[string_fixture] = [lf(string_fixture)] return my_dict @@ -49,7 +50,7 @@ def add_list_of_fixture_to_dict( RANK = 8 RANKS = [2, 6] - +instances = [] # dict_fixtures_models = [] @@ -89,14 +90,6 @@ def convenientplnpca( ) -# class convenientplnpca(PLNPCA): -# @singledispatchmethod -# def __init__(self,counts, covariates=None, offsets= None, offsets_formula=None, dict_initialization=None): -# super().__init__(counts, covariates, offsets, offsets_formula, dict_initialization) - -# def _(formula, data, offsets.) 
- - @convenientplnpca.register(str) def _(formula, data, offsets_formula=None, dict_initialization=None): return PLNPCA( @@ -108,6 +101,22 @@ def _(formula, data, offsets_formula=None, dict_initialization=None): ) +def generate_new_model(model, *args, **kwargs): + name_dir = model.directory_name + name = model.NAME + if name in ("PLN", "_PLNPCA"): + path = model.path_to_directory + name_dir + init = load_model(path) + if name == "PLN": + new = PLN(*args, **kwargs, dict_initialization=init) + if name == "_PLNPCA": + new = convenient_plnpca(*args, **kwargs, dict_initialization=init) + if name == "PLNPCA": + init = load_plnpca(name_dir) + new = convenientplnpca(*args, **kwargs, dict_initialization=init) + return new + + def cache(func): dict_cache = {} @@ -124,69 +133,65 @@ dict_fixtures = {} @pytest.fixture(params=params) -@cache def simulated_pln_0cov_array(request): cls = request.param - pln_full = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) - return pln_full + pln = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) + return pln -@pytest.fixture -def simulated_fitted_pln_0cov_array(simulated_pln_0cov_array): - simulated_pln_0cov_array.fit() - return simulated_pln_0cov_array +@pytest.fixture(params=params) +@cache +def simulated_fitted_pln_0cov_array(request): + cls = request.param + pln = cls(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) + pln.fit() + return pln @pytest.fixture(params=params) def simulated_pln_0cov_formula(request): cls = request.param - pln_full = cls("counts ~ 0", data_sim_0cov) - return pln_full + pln = cls("counts ~ 0", data_sim_0cov) + return pln -@pytest.fixture -def simulated_fitted_pln_0cov_formula(simulated_pln_0cov_formula): - simulated_pln_0cov_formula.fit() - return simulated_pln_0cov_formula +@pytest.fixture(params=params) +@cache +def simulated_fitted_pln_0cov_formula(request): + cls = request.param + pln = cls("counts ~ 0", data_sim_0cov) + pln.fit() + return pln @pytest.fixture def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula): simulated_fitted_pln_0cov_formula.save() - path = simulated_fitted_pln_0cov_formula.model_path - name = simulated_fitted_pln_0cov_formula.NAME - if name == "PLN" or name == "_PLNPCA": - init = load_model(path) - if name == "PLNPCA": - init = load_plnpca(path) - new = simulated_loaded_pln_0cov_formula.get_class( - "counts ~0", data_sim_0cov, dict_initialization=init + return generate_new_model( + simulated_fitted_pln_0cov_formula, + "counts ~ 0", + data_sim_0cov, ) - return new @pytest.fixture def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): simulated_fitted_pln_0cov_array.save() - path = simulated_fitted_pln_0cov_array.model_path - name = simulated_fitted_pln_0cov_array.NAME - if name == "PLN" or name == "_PLNPCA": - init = load_model(path) - if name == "PLNPCA": - init = load_plnpca(path) - new = simulated_fitted_pln_0cov_array.get_class( + return generate_new_model( + simulated_fitted_pln_0cov_array, counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov, - dict_initialization=init, ) - return new sim_pln_0cov_instance = [ "simulated_pln_0cov_array", "simulated_pln_0cov_formula", ] + +instances = sim_pln_0cov_instance + instances + dict_fixtures = add_list_of_fixture_to_dict( dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance ) @@ -214,6 +219,7 @@ dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_0cov", sim_p @pytest.fixture(params=params) +@cache def simulated_pln_2cov_array(request): cls = request.param pln_full = 
cls(counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov) @@ -227,7 +233,9 @@ def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array): @pytest.fixture(params=params) -def simulated_pln_2cov_formula(): +@cache +def simulated_pln_2cov_formula(request): + cls = request.param pln_full = cls("counts ~ 0 + covariates", data_sim_2cov) return pln_full @@ -241,40 +249,29 @@ def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula): @pytest.fixture def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula): simulated_fitted_pln_2cov_formula.save() - path = simulated_fitted_pln_2cov_formula.model_path - name = simulated_fitted_pln_2cov_formula.NAME - if name == "PLN": - init = load_model(path) - if name == "PLNPCA": - init = load_plnpca(path) - new = simulated_fitted_pln_2cov_formula.get_class( - "counts ~1", data_sim_2cov, dict_initialization=init + return generate_new_model( + simulated_fitted_pln_2cov_formula, + "counts ~0 + covariates", + data_sim_2cov, ) - return new @pytest.fixture def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): simulated_fitted_pln_2cov_array.save() - path = simulated_fitted_pln_2cov_array.model_path - name = simulated_fitted_pln_2cov_array.NAME - if name == "PLN" or name == "_PLNPCA": - init = load_model(path) - if name == "PLNPCA": - init = load_model(path) - new = simulated_fitted_pln_2cov_array.get_class( + return generate_new_model( + simulated_fitted_pln_2cov_array, counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov, - dict_initialization=init, ) - return new sim_pln_2cov_instance = [ "simulated_pln_2cov_array", "simulated_pln_2cov_formula", ] +instances = sim_pln_2cov_instance + instances dict_fixtures = add_list_of_fixture_to_dict( dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance @@ -303,6 +300,7 @@ dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_p @pytest.fixture(params=params) +@cache def real_pln_intercept_array(request): cls = request.param pln_full = cls(counts_real) @@ -316,9 +314,9 @@ def real_fitted_pln_intercept_array(real_pln_intercept_array): @pytest.fixture(params=params) +@cache def real_pln_intercept_formula(request): cls = request.param - print("cls:", cls) pln_full = cls("counts ~ 1", data_real) return pln_full @@ -332,37 +330,23 @@ def real_fitted_pln_intercept_formula(real_pln_intercept_formula): @pytest.fixture def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): real_fitted_pln_intercept_formula.save() - path = real_fitted_pln_intercept_formula.model_path - name = real_fitted_pln_intercept_formula.NAME - if name == "PLN" or name == "_PLNPCA": - init = load_model(path) - if name == "PLNPCA": - init = load_plnpca(path) - new = real_fitted_pln_intercept_formula.get_class( - "counts~ 1", data_real, dict_initialization=init + return generate_new_model( + real_fitted_pln_intercept_formula, "counts ~ 1", data_real ) - return new @pytest.fixture def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): real_fitted_pln_intercept_array.save() - path = real_fitted_pln_intercept_array.model_path - name = real_fitted_pln_intercept_array.NAME - if name == "PLN" or name == "_PLNPCA": - init = load_model(path) - if name == "PLNPCA": - init = load_plnpca(path) - new = real_fitted_pln_intercept_array.get_class( - counts_real, dict_initialization=init - ) - return new + return generate_new_model(real_fitted_pln_intercept_array, counts_real) real_pln_instance = [ "real_pln_intercept_array", "real_pln_intercept_formula", ] +instances = 
real_pln_instance + instances + dict_fixtures = add_list_of_fixture_to_dict( dict_fixtures, "real_pln_instance", real_pln_instance ) @@ -407,14 +391,11 @@ dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln) sim_pln = sim_pln_2cov + sim_pln_0cov dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln) -all_pln = real_pln + sim_pln +all_pln = real_pln + sim_pln + instances +dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "instances", instances) dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln) for string_fixture in all_pln: + print("string_fixture", string_fixture) dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture) - -pytest_plugins = [ - fixture_file.replace("/", ".").replace(".py", "") - for fixture_file in glob.glob("src/**/tests/fixtures/[!__]*.py", recursive=True) -] diff --git a/tests/import_fixtures.py b/tests/import_fixtures.py index e2d9b1ef..e69de29b 100644 --- a/tests/import_fixtures.py +++ b/tests/import_fixtures.py @@ -1,285 +0,0 @@ -import pytest -from pytest_lazyfixture import lazy_fixture as lf -from pyPLNmodels import load_model -from tests.import_data import ( - data_sim_0cov, - data_sim_2cov, - data_real, -) - -counts_sim_0cov = data_sim_0cov["counts"] -covariates_sim_0cov = data_sim_0cov["covariates"] -offsets_sim_0cov = data_sim_0cov["offsets"] - -counts_sim_2cov = data_sim_2cov["counts"] -covariates_sim_2cov = data_sim_2cov["covariates"] -offsets_sim_2cov = data_sim_2cov["offsets"] - -counts_real = data_real["counts"] - - -def add_fixture_to_dict(my_dict, string_fixture): - my_dict[string_fixture] = lf(string_fixture) - return my_dict - - -def add_list_of_fixture_to_dict( - my_dict, name_of_list_of_fixtures, list_of_string_fixtures -): - my_dict[name_of_list_of_fixtures] = [] - for string_fixture in list_of_string_fixtures: - my_dict[name_of_list_of_fixtures].append(lf(string_fixture)) - return my_dict - - -def get_dict_fixtures(PLNor_PLNPCA): - dict_fixtures = {} - - @pytest.fixture - def simulated_pln_0cov_array(): - pln_full = PLNor_PLNPCA(counts_sim_0cov, covariates_sim_0cov, offsets_sim_0cov) - return pln_full - - @pytest.fixture - def simulated_fitted_pln_0cov_array(simulated_pln_0cov_array): - simulated_pln_0cov_array.fit() - return simulated_pln_0cov_array - - @pytest.fixture - def simulated_pln_0cov_formula(): - pln_full = PLNor_PLNPCA("counts ~ 0", data_sim_0cov) - return pln_full - - @pytest.fixture - def simulated_fitted_pln_0cov_formula(simulated_pln_0cov_formula): - simulated_pln_0cov_array.fit() - return simulated_pln_0cov_formula - - @pytest.fixture - def simulated_loaded_pln_0cov_formula(simulated_fitted_pln_0cov_formula): - simulated_fitted_pln_0cov_formula.save() - if simulated_fitted_pln_0cov_formula.NAME == "PLN": - init = load_model("PLN_nbcov_0") - if simulated_fitted_pln_0cov_formula.NAME == "_PLNPCA": - init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_0cov_formula.rank}") - new = PLNor_PLNPCA("counts ~0", data_sim_0cov, dict_initialization=init) - return new - - @pytest.fixture - def simulated_loaded_pln_0cov_array(simulated_fitted_pln_0cov_array): - simulated_fitted_pln_0cov_array.save() - if simulated_fitted_pln_0cov_array.NAME == "PLN": - init = load_model("PLN_nbcov_0") - if simulated_fitted_pln_0cov_array.NAME == "_PLNPCA": - init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_0cov_array.rank}") - new = PLNor_PLNPCA( - counts_sim_0cov, - covariates_sim_0cov, - offsets_sim_0cov, - dict_initialization=init, - ) - return new - - 
sim_pln_0cov_instance = [ - "simulated_pln_0cov_array", - "simulated_pln_0cov_formula", - ] - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov_instance", sim_pln_0cov_instance - ) - - sim_pln_0cov_fitted = [ - "simulated_fitted_pln_0cov_array", - "simulated_fitted_pln_0cov_formula", - ] - - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov_fitted", sim_pln_0cov_fitted - ) - - sim_pln_0cov_loaded = [ - "simulated_loaded_pln_0cov_array", - "simulated_loaded_pln_0cov_formula", - ] - - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov_loaded", sim_pln_0cov_loaded - ) - - sim_pln_0cov = sim_pln_0cov_instance + sim_pln_0cov_fitted + sim_pln_0cov_loaded - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_0cov", sim_pln_0cov - ) - - @pytest.fixture - def simulated_pln_2cov_array(): - pln_full = PLNor_PLNPCA(counts_sim_2cov, covariates_sim_2cov, offsets_sim_2cov) - return pln_full - - @pytest.fixture - def simulated_fitted_pln_2cov_array(simulated_pln_2cov_array): - simulated_pln_2cov_array.fit() - return simulated_pln_2cov_array - - @pytest.fixture - def simulated_pln_2cov_formula(): - pln_full = PLNor_PLNPCA("counts ~ 0 + covariates", data_sim_2cov) - return pln_full - - @pytest.fixture - def simulated_fitted_pln_2cov_formula(simulated_pln_2cov_formula): - simulated_pln_2cov_formula.fit() - return simulated_pln_2cov_formula - - @pytest.fixture - def simulated_loaded_pln_2cov_formula(simulated_fitted_pln_2cov_formula): - simulated_fitted_pln_2cov_formula.save() - if simulated_fitted_pln_2cov_formula.NAME == "PLN": - init = load_model("PLN_nbcov_2") - if simulated_fitted_pln_2cov_formula.NAME == "_PLNPCA": - init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_2cov_formula.rank}") - new = PLNor_PLNPCA("counts ~2", data_sim_2cov, dict_initialization=init) - return new - - @pytest.fixture - def simulated_loaded_pln_2cov_array(simulated_fitted_pln_2cov_array): - simulated_fitted_pln_2cov_array.save() - if simulated_fitted_pln_2cov_array.NAME == "PLN": - init = load_model("PLN_nbcov_2") - if simulated_fitted_pln_2cov_array.NAME == "_PLNPCA": - init = load_model(f"PLNPCA_rank_{simulated_fitted_pln_2cov_array.rank}") - new = PLNor_PLNPCA( - counts_sim_2cov, - covariates_sim_2cov, - offsets_sim_2cov, - dict_initialization=init, - ) - return new - - sim_pln_2cov_instance = [ - "simulated_pln_2cov_array", - "simulated_pln_2cov_formula", - ] - - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov_instance", sim_pln_2cov_instance - ) - - sim_pln_2cov_fitted = [ - "simulated_fitted_pln_2cov_array", - "simulated_fitted_pln_2cov_formula", - ] - - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov_fitted", sim_pln_2cov_fitted - ) - - sim_pln_2cov_loaded = [ - "simulated_loaded_pln_2cov_array", - "simulated_loaded_pln_2cov_formula", - ] - - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov_loaded", sim_pln_2cov_loaded - ) - - sim_pln_2cov = sim_pln_2cov_instance + sim_pln_2cov_fitted + sim_pln_2cov_loaded - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "sim_pln_2cov", sim_pln_2cov - ) - - @pytest.fixture - def real_pln_intercept_array(): - pln_full = PLNor_PLNPCA(counts_real) - return pln_full - - @pytest.fixture - def real_fitted_pln_intercept_array(real_pln_intercept_array): - real_pln_intercept_array.fit() - return real_pln_intercept_array - - @pytest.fixture - def real_pln_intercept_formula(): - pln_full = 
PLNor_PLNPCA("counts ~ 1", data_real) - return pln_full - - @pytest.fixture - def real_fitted_pln_intercept_formula(real_pln_intercept_formula): - real_pln_intercept_formula.fit() - return real_pln_intercept_formula - - @pytest.fixture - def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): - real_fitted_pln_intercept_formula.save() - if real_fitted_pln_intercept_formula.NAME == "PLN": - init = load_model("PLN_nbcov_2") - if real_fitted_pln_intercept_formula.NAME == "_PLNPCA": - init = load_model(f"PLNPCA_rank_{real_fitted_pln_intercept_formula.rank}") - new = PLNor_PLNPCA("counts ~2", data_real, dict_initialization=init) - return new - - @pytest.fixture - def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): - real_fitted_pln_intercept_array.save() - if real_fitted_pln_intercept_array.NAME == "PLN": - init = load_model("PLN_nbcov_2") - if real_fitted_pln_intercept_array.NAME == "_PLNPCA": - init = load_model(f"PLNPCA_rank_{real_fitted_pln_intercept_array.rank}") - new = PLNor_PLNPCA(counts_real, dict_initialization=init) - return new - - real_pln_instance = [ - "real_pln_intercept_array", - "real_pln_intercept_formula", - ] - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "real_pln_instance", real_pln_instance - ) - - real_pln_fitted = [ - "real_fitted_pln_intercept_array", - "real_fitted_pln_intercept_formula", - ] - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "real_pln_fitted", real_pln_fitted - ) - - real_pln_loaded = [ - "real_loaded_pln_intercept_array", - "real_loaded_pln_intercept_formula", - ] - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "real_pln_loaded", real_pln_loaded - ) - - sim_loaded_pln = sim_pln_0cov_loaded + sim_pln_2cov_loaded - - loaded_pln = real_pln_loaded + sim_loaded_pln - dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "loaded_pln", loaded_pln) - - simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted - ) - - fitted_pln = real_pln_fitted + simulated_pln_fitted - dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln) - - loaded_and_fitted_pln = fitted_pln + loaded_pln - dict_fixtures = add_list_of_fixture_to_dict( - dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln - ) - - real_pln = real_pln_instance + real_pln_fitted + real_pln_loaded - dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "real_pln", real_pln) - - sim_pln = sim_pln_2cov + sim_pln_0cov - dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln", sim_pln) - - all_pln = real_pln + sim_pln - dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "all_pln", all_pln) - - for string_fixture in all_pln: - dict_fixtures = add_fixture_to_dict(dict_fixtures, string_fixture) - - return dict_fixtures diff --git a/tests/test_args.py b/tests/test_args.py index 16c8a73d..2708851d 100644 --- a/tests/test_args.py +++ b/tests/test_args.py @@ -49,3 +49,18 @@ def test_pandas_init(instance): @pytest.mark.parametrize("instance", all_instances) def test_numpy_init(instance): instance.fit(counts_sim.numpy(), covariates_sim.numpy(), offsets_sim.numpy()) + + +@pytest.mark.parametrize("sim_pln", simulated_any_pln) +def test_only_counts(sim_pln): + sim_pln.fit() + + +@pytest.mark.parametrize("sim_pln", simulated_any_pln) +def test_only_counts_and_offsets(sim_pln): + sim_pln.fit(counts=counts_sim, offsets=offsets_sim) + + 
+@pytest.mark.parametrize("sim_pln", simulated_any_pln) +def test_only_Y_and_cov(sim_pln): + sim_pln.fit(counts=counts_sim, covariates=covariates_sim) diff --git a/tests/test_common.py b/tests/test_common.py index 97f81357..a403ac89 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2,40 +2,11 @@ import os import functools import torch -import numpy as np -import pandas as pd import pytest -from pytest_lazyfixture import lazy_fixture as lf - -from pyPLNmodels.models import PLN, _PLNPCA -from tests.utils import MSE - -# from tests.import_fixtures import get_dict_fixtures from tests.conftest import dict_fixtures -# dict_fixtures_pln_full = dict_fixtures_models[0] -# dict_fixtures_plnpca = dict_fixtures_models[1] -# dict_fixturesplnpca = dict_fixtures_models[2] - - -def get_pln_and_plncpca_fixtures(key): - return dict_fixtures_pln_full[key] + dict_fixtures_plnpca[key] - - -def get_pca_fixtures(key): - return dict_fixtures_plnpca[key] + dict_fixturesplnpca[key] - - -def get_all_fixtures(key): - return ( - dict_fixtures_plnpca[key] - + dict_fixtures_pln_full[key] - + dict_fixturesplnpca[key] - ) - - def filter_models(models_name): def decorator(my_test): @functools.wraps(my_test) @@ -50,26 +21,22 @@ def filter_models(models_name): return decorator -# @pytest.mark.parametrize("any_pln", [dict_fixtures["simulated_pln_0cov_array"]]) -# @pytest.mark.parametrize("any_pln", [dict_fixtures["simulated_pln_0cov_array"]]) +@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) @filter_models(["PLN", "_PLNPCA"]) -def test_properties(simulated_fitted_pln_0cov_array): - assert hasattr(simulated_fitted_pln_0cov_array, "model_parameters") - assert hasattr(simulated_fitted_pln_0cov_array, "latent_parameters") - print("model_param", simulated_fitted_pln_0cov_array.model_parameters) - assert hasattr(simulated_fitted_pln_0cov_array, "latent_variables") - assert hasattr(simulated_fitted_pln_0cov_array, "optim_parameters") +def test_properties(any_pln): + assert hasattr(any_pln, "latent_parameters") + assert hasattr(any_pln, "latent_variables") + assert hasattr(any_pln, "optim_parameters") + assert hasattr(any_pln, "model_parameters") -@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) +@pytest.mark.parametrize("any_pln", dict_fixtures["loaded_and_fitted_pln"]) def test_print(any_pln): print(any_pln) -print("len :", len(dict_fixtures["all_pln"])) - -""" -@pytest.mark.parametrize("any_pln", all_fitted_models) +@pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) +@filter_models(["PLN", "_PLNPCA"]) def test_show_coef_transform_covariance_pcaprojected(any_pln): any_pln.show() any_pln.plotargs.show_loss(savefig=True) @@ -83,52 +50,33 @@ def test_show_coef_transform_covariance_pcaprojected(any_pln): any_pln.pca_projected_latent_variables(n_components=any_pln.dim + 1) -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_predict_simulated(sim_pln): - X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov - 1)) - prediction = sim_pln.predict(X) - expected = ( - torch.stack((torch.ones(sim_pln.n_samples, 1), X), axis=1).squeeze() - @ sim_pln.coef - ) - assert torch.all(torch.eq(expected, prediction)) +print("loaded and fitted ", dict_fixtures["loaded_and_fitted_pln"]) -@pytest.mark.parametrize("real_pln", real_any_pln) -def test_predict_real(real_pln): - prediction = real_pln.predict() - expected = torch.ones(real_pln.n_samples, 1) @ real_pln.coef - assert torch.all(torch.eq(expected, prediction)) +@pytest.mark.parametrize("sim_pln", 
dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLN", "_PLNPCA"]) +def test_predict_simulated(sim_pln): + if sim_pln.nb_cov == 0: + assert sim_pln.predict() is None + with pytest.raises(AttributeError): + sim_pln.predict(1) + else: + X = torch.randn((sim_pln.n_samples, sim_pln.nb_cov)) + prediction = sim_pln.predict(X) + expected = X @ sim_pln.coef + assert torch.all(torch.eq(expected, prediction)) -@pytest.mark.parametrize("any_pln", all_fitted_models) -def test_print(any_pln): - print(any_pln) +print("instances:", dict_fixtures["instances"]) -@pytest.mark.parametrize("any_instance_pln", all_instances) +@pytest.mark.parametrize("any_instance_pln", dict_fixtures["instances"]) def test_verbose(any_instance_pln): - any_instance_pln.fit( - counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim, verbose=True - ) - - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_counts(sim_pln): - sim_pln.fit(counts=counts_sim) + any_instance_pln.fit(verbose=True, tol=0.1) -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_counts_and_offsets(sim_pln): - sim_pln.fit(counts=counts_sim, offsets=offsets_sim) - - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_Y_and_cov(sim_pln): - sim_pln.fit(counts=counts_sim, covariates=covariates_sim) - - -@pytest.mark.parametrize("simulated_fitted_any_pln", simulated_any_pln) +@pytest.mark.parametrize("simulated_fitted_any_pln", dict_fixtureskk["sim_pln"]) +@filter_models(["PLN", "_PLNPCA"]) def test_find_right_covariance(simulated_fitted_any_pln): mse_covariance = MSE(simulated_fitted_any_pln.covariance - true_covariance) assert mse_covariance < 0.05 @@ -140,6 +88,8 @@ def test_find_right_coef(sim_pln): assert mse_coef < 0.1 +""" + def test_number_of_iterations_pln_full(simulated_fitted_pln_full_0cov): nb_iterations = len(simulated_fitted_pln_full_0cov.elbos_list) assert 50 < nb_iterations < 300 -- GitLab From f83b222d31a4d7b006578c657c01f2ca0315ec1a Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Thu, 11 May 2023 09:52:25 +0200 Subject: [PATCH 18/24] pass more tests. 
--- tests/conftest.py | 13 ++++++++++++- tests/test_common.py | 42 ++++++++++++++++++++++++++++++------------ 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 26daaf2e..e9adf442 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,6 +111,8 @@ def generate_new_model(model, *args, **kwargs): new = PLN(*args, **kwargs, dict_initialization=init) if name == "_PLNPCA": new = convenient_plnpca(*args, **kwargs, dict_initialization=init) + print("now", new.nb_cov) + x if name == "PLNPCA": init = load_plnpca(name_dir) new = convenientplnpca(*args, **kwargs, dict_initialization=init) @@ -330,6 +332,7 @@ def real_fitted_pln_intercept_formula(real_pln_intercept_formula): @pytest.fixture def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): real_fitted_pln_intercept_formula.save() + print("before", real_fitted_pln_intercept_formula.nb_cov) return generate_new_model( real_fitted_pln_intercept_formula, "counts ~ 1", data_real ) @@ -376,10 +379,18 @@ simulated_pln_fitted = sim_pln_0cov_fitted + sim_pln_2cov_fitted dict_fixtures = add_list_of_fixture_to_dict( dict_fixtures, "simulated_pln_fitted", simulated_pln_fitted ) - fitted_pln = real_pln_fitted + simulated_pln_fitted dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "fitted_pln", fitted_pln) + +loaded_and_fitted_sim_pln = simulated_pln_fitted + sim_loaded_pln +loaded_and_fitted_real_pln = real_pln_fitted + real_pln_loaded +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_real_pln", loaded_and_fitted_real_pln +) +dict_fixtures = add_list_of_fixture_to_dict( + dict_fixtures, "loaded_and_fitted_sim_pln", loaded_and_fitted_sim_pln +) loaded_and_fitted_pln = fitted_pln + loaded_pln dict_fixtures = add_list_of_fixture_to_dict( dict_fixtures, "loaded_and_fitted_pln", loaded_and_fitted_pln diff --git a/tests/test_common.py b/tests/test_common.py index a403ac89..18bbd16c 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -5,6 +5,9 @@ import torch import pytest from tests.conftest import dict_fixtures +from tests.utils import MSE + +from tests.import_data import true_sim_0cov, true_sim_2cov def filter_models(models_name): @@ -50,9 +53,6 @@ def test_show_coef_transform_covariance_pcaprojected(any_pln): any_pln.pca_projected_latent_variables(n_components=any_pln.dim + 1) -print("loaded and fitted ", dict_fixtures["loaded_and_fitted_pln"]) - - @pytest.mark.parametrize("sim_pln", dict_fixtures["loaded_and_fitted_pln"]) @filter_models(["PLN", "_PLNPCA"]) def test_predict_simulated(sim_pln): @@ -67,29 +67,47 @@ def test_predict_simulated(sim_pln): assert torch.all(torch.eq(expected, prediction)) -print("instances:", dict_fixtures["instances"]) - - @pytest.mark.parametrize("any_instance_pln", dict_fixtures["instances"]) def test_verbose(any_instance_pln): any_instance_pln.fit(verbose=True, tol=0.1) -@pytest.mark.parametrize("simulated_fitted_any_pln", dict_fixtureskk["sim_pln"]) +@pytest.mark.parametrize( + "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_sim_pln"] +) @filter_models(["PLN", "_PLNPCA"]) def test_find_right_covariance(simulated_fitted_any_pln): + if simulated_fitted_any_pln.nb_cov == 0: + true_covariance = true_sim_0cov["Sigma"] + elif simulated_fitted_any_pln.nb_cov == 2: + true_covariance = true_sim_2cov["Sigma"] mse_covariance = MSE(simulated_fitted_any_pln.covariance - true_covariance) assert mse_covariance < 0.05 -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def 
test_find_right_coef(sim_pln): - mse_coef = MSE(sim_pln.coef - true_coef) - assert mse_coef < 0.1 +@pytest.mark.parametrize( + "real_fitted_and_loaded_pln", dict_fixtures["loaded_and_fitted_real_pln"] +) +@filter_models(["PLN", "_PLNPCA"]) +def test_right_covariance_shape(real_fitted_and_loaded_pln): + assert real_fitted_and_loaded_pln.covariance.shape == (100, 100) -""" +@pytest.mark.parametrize( + "simulated_fitted_any_pln", dict_fixtures["loaded_and_fitted_pln"] +) +@filter_models(["PLN", "_PLNPCA"]) +def test_find_right_coef(simulated_fitted_any_pln): + if simulated_fitted_any_pln.nb_cov == 2: + true_coef = true_sim_2cov["beta"] + mse_coef = MSE(simulated_fitted_any_pln.coef - true_coef) + assert mse_coef < 0.1 + elif simulated_fitted_any_pln.nb_cov == 0: + print("nb cov", simulated_fitted_any_pln.nb_cov) + assert simulated_fitted_any_pln.coef is None + +""" def test_number_of_iterations_pln_full(simulated_fitted_pln_full_0cov): nb_iterations = len(simulated_fitted_pln_full_0cov.elbos_list) assert 50 < nb_iterations < 300 -- GitLab From 03eb9e554b1c6b5aa07e86386a95523199dba25d Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Thu, 11 May 2023 09:53:15 +0200 Subject: [PATCH 19/24] returns None instead of zero when covariates is none for beta. --- pyPLNmodels/_closed_forms.py | 2 +- pyPLNmodels/_utils.py | 6 ++---- pyPLNmodels/models.py | 7 +------ 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py index 087bc386..63d18014 100644 --- a/pyPLNmodels/_closed_forms.py +++ b/pyPLNmodels/_closed_forms.py @@ -16,7 +16,7 @@ def closed_formula_covariance(covariates, latent_mean, latent_var, coef, n_sampl def closed_formula_coef(covariates, latent_mean): """Closed form for coef for the M step for the noPCA model.""" if covariates is None: - return 0 + return None return torch.mm( torch.mm(torch.inverse(torch.mm(covariates.T, covariates)), covariates.T), latent_mean, diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py index 23a02fd2..f7da25e7 100644 --- a/pyPLNmodels/_utils.py +++ b/pyPLNmodels/_utils.py @@ -630,10 +630,8 @@ def extract_data_from_formula(formula, data): dmatrix = dmatrices(formula, data=data) counts = dmatrix[0] covariates = dmatrix[1] - if covariates.size > 0: - pass - covariates = remove_useless_intercepts(covariates) - else: + print("covariates size:", covariates.size) + if covariates.size == 0: covariates = None offsets = data.get("offsets", None) return counts, covariates, offsets diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index cdf11513..0693f999 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -104,12 +104,7 @@ class _PLN(ABC): offsets_formula="logsum", dict_initialization=None, ): - dmatrix = dmatrices(formula, data=data) - counts = dmatrix[0] - covariates = dmatrix[1] - if len(covariates) > 0: - covariates = remove_useless_intercepts(covariates) - offsets = data.get("offsets", None) + counts, covariates, offsets = extract_data_from_formula(formula, data) self.__init__(counts, covariates, offsets, offsets_formula, dict_initialization) def set_init_parameters(self, dict_initialization): -- GitLab From 35f62def40cd8be908fe78b97033b7063bd29dcb Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Fri, 12 May 2023 17:58:49 +0200 Subject: [PATCH 20/24] continue to implement the tests --- pyPLNmodels/models.py | 2 +- tests/conftest.py | 12 ++++--- tests/test_common.py | 74 
++++++++++++++++++++----------------------- 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index 0693f999..c7b4c7a3 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -439,7 +439,7 @@ class _PLN(ABC): pd.DataFrame(np.array(value.cpu().detach())).to_csv( filename, header=None, index=None ) - else: + elif value is not None: pd.DataFrame(np.array([value])).to_csv( filename, header=None, index=None ) diff --git a/tests/conftest.py b/tests/conftest.py index e9adf442..904240c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import glob from functools import singledispatch import pytest +import torch from pytest_lazyfixture import lazy_fixture as lf from pyPLNmodels import load_model, load_plnpca from pyPLNmodels.models import PLN, _PLNPCA, PLNPCA @@ -111,8 +112,6 @@ def generate_new_model(model, *args, **kwargs): new = PLN(*args, **kwargs, dict_initialization=init) if name == "_PLNPCA": new = convenient_plnpca(*args, **kwargs, dict_initialization=init) - print("now", new.nb_cov) - x if name == "PLNPCA": init = load_plnpca(name_dir) new = convenientplnpca(*args, **kwargs, dict_initialization=init) @@ -305,7 +304,7 @@ dict_fixtures = add_list_of_fixture_to_dict(dict_fixtures, "sim_pln_2cov", sim_p @cache def real_pln_intercept_array(request): cls = request.param - pln_full = cls(counts_real) + pln_full = cls(counts_real, covariates=torch.ones((counts_real.shape[0], 1))) return pln_full @@ -332,7 +331,6 @@ def real_fitted_pln_intercept_formula(real_pln_intercept_formula): @pytest.fixture def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): real_fitted_pln_intercept_formula.save() - print("before", real_fitted_pln_intercept_formula.nb_cov) return generate_new_model( real_fitted_pln_intercept_formula, "counts ~ 1", data_real ) @@ -341,7 +339,11 @@ def real_loaded_pln_intercept_formula(real_fitted_pln_intercept_formula): @pytest.fixture def real_loaded_pln_intercept_array(real_fitted_pln_intercept_array): real_fitted_pln_intercept_array.save() - return generate_new_model(real_fitted_pln_intercept_array, counts_real) + return generate_new_model( + real_fitted_pln_intercept_array, + counts_real, + covariates=torch.ones((counts_real.shape[0], 1)), + ) real_pln_instance = [ diff --git a/tests/test_common.py b/tests/test_common.py index 18bbd16c..bfa94a47 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -103,80 +103,77 @@ def test_find_right_coef(simulated_fitted_any_pln): mse_coef = MSE(simulated_fitted_any_pln.coef - true_coef) assert mse_coef < 0.1 elif simulated_fitted_any_pln.nb_cov == 0: - print("nb cov", simulated_fitted_any_pln.nb_cov) assert simulated_fitted_any_pln.coef is None -""" -def test_number_of_iterations_pln_full(simulated_fitted_pln_full_0cov): - nb_iterations = len(simulated_fitted_pln_full_0cov.elbos_list) +@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) +@filter_models(["PLN"]) +def test_number_of_iterations_pln_full(fitted_pln): + nb_iterations = len(fitted_pln.elbos_list) assert 50 < nb_iterations < 300 -def test_computable_elbopca(instance__plnpca, simulated_fitted__plnpca): - instance__plnpca.counts = simulated_fitted__plnpca.counts - instance__plnpca.covariates = simulated_fitted__plnpca.covariates - instance__plnpca.offsets = simulated_fitted__plnpca.offsets - instance__plnpca.latent_mean = simulated_fitted__plnpca.latent_mean - instance__plnpca.latent_var = simulated_fitted__plnpca.latent_var - instance__plnpca.components 
= simulated_fitted__plnpca.components - instance__plnpca.coef = simulated_fitted__plnpca.coef - instance__plnpca.compute_elbo() - - -def test_computable_elbo_full(instance_pln_full, simulated_fitted_pln_full_0cov): - instance_pln_full.counts = simulated_fitted_pln_full_0cov.counts - instance_pln_full.covariates = simulated_fitted_pln_full_0cov.covariates - instance_pln_full.offsets = simulated_fitted_pln_full_0cov.offsets - instance_pln_full.latent_mean = simulated_fitted_pln_full_0cov.latent_mean - instance_pln_full.latent_var = simulated_fitted_pln_full_0cov.latent_var - instance_pln_full.covariance = simulated_fitted_pln_full_0cov.covariance - instance_pln_full.coef = simulated_fitted_pln_full_0cov.coef - instance_pln_full.compute_elbo() +@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) +@filter_models(["_PLNPCA"]) +def test_number_of_iterations_plnpca(fitted_pln): + nb_iterations = len(fitted_pln.elbos_list) + assert 100 < nb_iterations < 5000 -def test_fail_count_setter(simulated_fitted_pln_full_0cov): +@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLN"]) +def test_fail_count_setter(pln): wrong_counts = torch.randint(size=(10, 5), low=0, high=10) with pytest.raises(Exception): - simulated_fitted_pln_full_0cov.counts = wrong_counts + pln.counts = wrong_counts -@pytest.mark.parametrize("any_pln", all_fitted_models) -def test_setter_with_numpy(any_pln): - np_counts = any_pln.counts.numpy() - any_pln.counts = np_counts +@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) +@filter_models(["PLN", "PLNPCA"]) +def test_setter_with_numpy(pln): + np_counts = pln.counts.numpy() + pln.counts = np_counts -@pytest.mark.parametrize("any_pln", all_fitted_models) -def test_setter_with_pandas(any_pln): +@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) +@filter_models(["PLN", "PLNPCA"]) +def test_setter_with_pandas(pln): pd_counts = pd.DataFrame(any_pln.counts.numpy()) any_pln.counts = pd_counts -@pytest.mark.parametrize("instance", all_instances) +@pytest.mark.parametrize("instance", dict_fixtures["instances"]) def test_random_init(instance): - instance.fit(counts_sim, covariates_sim, offsets_sim, do_smart_init=False) + instance.fit(do_smart_init=False) -@pytest.mark.parametrize("instance", all_instances) +""" +@pytest.mark.parametrize("instance", dict_fixtures["instances"]) def test_print_end_of_fitting_message(instance): instance.fit(counts_sim, covariates_sim, offsets_sim, nb_max_iteration=4) @pytest.mark.parametrize("any_pln", all_fitted_models) +@filter_models(["PLN", "_PLNPCA"]) def test_fail_wrong_covariates_prediction(any_pln): - X = torch.randn(any_pln.n_samples, any_pln.nb_cov) + X = torch.randn(any_pln.n_samples, any_pln.nb_cov+1) with pytest.raises(Exception): any_pln.predict(X) -@pytest.mark.parametrize("any__plnpca", all_fitted__plnpca) +@pytest.mark.parametrize( + "pln", dict_fixtures["loaded_and_fitted_pln"] +) +@filter_models(["PLN"]) def test_latent_var_pca(any__plnpca): assert any__plnpca.transform(project=False).shape == any__plnpca.counts.shape assert any__plnpca.transform().shape == (any__plnpca.n_samples, any__plnpca.rank) -@pytest.mark.parametrize("any_pln_full", all_fitted_pln_full) +@pytest.mark.parametrize( + "pln", dict_fixtures["loaded_and_fitted_pln"] +) +@filter_models(["PLN"]) def test_latent_var_pln_full(any_pln_full): assert any_pln_full.transform().shape == any_pln_full.counts.shape @@ -186,5 +183,4 @@ def test_wrong_rank(): with pytest.warns(UserWarning): instance.fit(counts=counts_sim, 
covariates=covariates_sim, offsets=offsets_sim) - """ -- GitLab From 6eaabd2ad8a7685d9a1134f0acac7847d2087c12 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Sat, 13 May 2023 00:55:01 +0200 Subject: [PATCH 21/24] finished to implement test_common and test_plnpca. still need to write sone setters properly. --- pyPLNmodels/models.py | 68 +++++++++++----- tests/test_common.py | 61 ++++----------- tests/test_plnpca.py | 178 ++++++++++-------------------------------- tests/utils.py | 15 ++++ 4 files changed, 122 insertions(+), 200 deletions(-) diff --git a/pyPLNmodels/models.py b/pyPLNmodels/models.py index c7b4c7a3..2497edb2 100644 --- a/pyPLNmodels/models.py +++ b/pyPLNmodels/models.py @@ -675,25 +675,38 @@ class PLNPCA: @property def covariates(self): - return self.models[0].covariates + return self.list_models[0].covariates @property def counts(self): - return self.models[0].counts + return self.list_models[0].counts + + @counts.setter + def counts(self, counts): + counts = format_data(counts) + if hasattr(self, "_counts"): + check_dimensions_are_equal(self._counts, counts) + self._counts = counts + + @covariates.setter + def covariates(self, covariates): + covariates = format_data(covariates) + # if hasattr(self,) + self._covariates = covariates @property def offsets(self): - return self.models[0].offsets + return self.list_models[0].offsets def init_models(self, ranks, dict_of_dict_initialization): if isinstance(ranks, (Iterable, np.ndarray)): - self.models = [] + self.list_models = [] for rank in ranks: if isinstance(rank, (int, np.integer)): dict_initialization = get_dict_initialization( rank, dict_of_dict_initialization ) - self.models.append( + self.list_models.append( _PLNPCA( self._counts, self._covariates, @@ -711,7 +724,7 @@ class PLNPCA: dict_initialization = get_dict_initialization( ranks, dict_of_dict_initialization ) - self.models = [ + self.list_models = [ _PLNPCA( self._counts, self._covariates, @@ -727,11 +740,11 @@ class PLNPCA: @property def ranks(self): - return [model.rank for model in self.models] + return [model.rank for model in self.list_models] @property def dict_models(self): - return {model.rank: model for model in self.models} + return {model.rank: model for model in self.list_models} def print_beginning_message(self): return f"Adjusting {len(self.ranks)} PLN models for PCA analysis \n" @@ -790,15 +803,15 @@ class PLNPCA: @property def BIC(self): - return {model.rank: int(model.BIC) for model in self.models} + return {model.rank: int(model.BIC) for model in self.list_models} @property def AIC(self): - return {model.rank: int(model.AIC) for model in self.models} + return {model.rank: int(model.AIC) for model in self.list_models} @property def loglikes(self): - return {model.rank: model.loglike for model in self.models} + return {model.rank: model.loglike for model in self.list_models} def show(self): bic = self.BIC @@ -842,7 +855,7 @@ class PLNPCA: def save(self, path_of_directory="./", ranks=None): if ranks is None: ranks = self.ranks - for model in self.models: + for model in self.list_models: if model.rank in ranks: model.save(path_of_directory) @@ -852,14 +865,18 @@ class PLNPCA: @property def n_samples(self): - return self.models[0].n_samples + return self.list_models[0].n_samples @property def _p(self): return self[self.ranks[0]].p + @property + def models(self): + return self.dict_models.values() + def __str__(self): - nb_models = len(self.models) + nb_models = len(self.list_models) delimiter = "\n" + "-" * 
NB_CHARACTERS_FOR_NICE_PLOT + "\n" to_print = delimiter to_print += f"Collection of {nb_models} PLNPCA models with \ @@ -1032,6 +1049,16 @@ class _PLNPCA(_PLN): ortho_components = torch.linalg.qr(self._components, "reduced")[0] return torch.mm(self.latent_variables, ortho_components).detach().cpu() + def pca_projected_latent_variables(self, n_components=None): + if n_components is None: + n_components = self.get_max_components() + if n_components > self.dim: + raise RuntimeError( + f"You ask more components ({n_components}) than variables ({self.dim})" + ) + pca = PCA(n_components=n_components) + return pca.fit_transform(self.projected_latent_variables.detach().cpu()) + @property def components(self): return self.attribute_or_none("_components") @@ -1041,13 +1068,16 @@ class _PLNPCA(_PLN): self._components = components def viz(self, ax=None, colors=None): - if self._rank != 2: - raise RuntimeError("Can't perform visualization for rank != 2.") if ax is None: ax = plt.gca() - proj_variables = self.projected_latent_variables - x = proj_variables[:, 0].cpu().numpy() - y = proj_variables[:, 1].cpu().numpy() + if self._rank < 2: + raise RuntimeError("Can't perform visualization for rank < 2.") + if self._rank > 2: + proj_variables = self.pca_projected_latent_variables(n_components=2) + if self._rank == 2: + proj_variables = self.projected_latent_variables.cpu().numpy() + x = proj_variables[:, 0] + y = proj_variables[:, 1] sns.scatterplot(x=x, y=y, hue=colors, ax=ax) covariances = torch.diag_embed(self._latent_var**2).detach().cpu() for i in range(covariances.shape[0]): diff --git a/tests/test_common.py b/tests/test_common.py index bfa94a47..06f3ef8e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,29 +1,15 @@ import os -import functools import torch import pytest +import pandas as pd from tests.conftest import dict_fixtures -from tests.utils import MSE +from tests.utils import MSE, filter_models from tests.import_data import true_sim_0cov, true_sim_2cov -def filter_models(models_name): - def decorator(my_test): - @functools.wraps(my_test) - def new_test(**kwargs): - fixture = next(iter(kwargs.values())) - if type(fixture).__name__ not in models_name: - return None - return my_test(**kwargs) - - return new_test - - return decorator - - @pytest.mark.parametrize("any_pln", dict_fixtures["fitted_pln"]) @filter_models(["PLN", "_PLNPCA"]) def test_properties(any_pln): @@ -138,8 +124,8 @@ def test_setter_with_numpy(pln): @pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) @filter_models(["PLN", "PLNPCA"]) def test_setter_with_pandas(pln): - pd_counts = pd.DataFrame(any_pln.counts.numpy()) - any_pln.counts = pd_counts + pd_counts = pd.DataFrame(pln.counts.numpy()) + pln.counts = pd_counts @pytest.mark.parametrize("instance", dict_fixtures["instances"]) @@ -147,40 +133,27 @@ def test_random_init(instance): instance.fit(do_smart_init=False) -""" @pytest.mark.parametrize("instance", dict_fixtures["instances"]) def test_print_end_of_fitting_message(instance): - instance.fit(counts_sim, covariates_sim, offsets_sim, nb_max_iteration=4) + instance.fit(nb_max_iteration=4) -@pytest.mark.parametrize("any_pln", all_fitted_models) +@pytest.mark.parametrize("pln", dict_fixtures["fitted_pln"]) @filter_models(["PLN", "_PLNPCA"]) -def test_fail_wrong_covariates_prediction(any_pln): - X = torch.randn(any_pln.n_samples, any_pln.nb_cov+1) +def test_fail_wrong_covariates_prediction(pln): + X = torch.randn(pln.n_samples, pln.nb_cov + 1) with pytest.raises(Exception): - any_pln.predict(X) + 
pln.predict(X) -@pytest.mark.parametrize( - "pln", dict_fixtures["loaded_and_fitted_pln"] -) -@filter_models(["PLN"]) -def test_latent_var_pca(any__plnpca): - assert any__plnpca.transform(project=False).shape == any__plnpca.counts.shape - assert any__plnpca.transform().shape == (any__plnpca.n_samples, any__plnpca.rank) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["_PLNPCA"]) +def test_latent_var_pca(plnpca): + assert plnpca.transform(project=False).shape == plnpca.counts.shape + assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank) -@pytest.mark.parametrize( - "pln", dict_fixtures["loaded_and_fitted_pln"] -) +@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) @filter_models(["PLN"]) -def test_latent_var_pln_full(any_pln_full): - assert any_pln_full.transform().shape == any_pln_full.counts.shape - - -def test_wrong_rank(): - instance = _PLNPCA(counts_sim.shape[1] + 1) - with pytest.warns(UserWarning): - instance.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - -""" +def test_latent_var_full(pln): + assert pln.transform().shape == pln.counts.shape diff --git a/tests/test_plnpca.py b/tests/test_plnpca.py index db0e324a..af06a2ad 100644 --- a/tests/test_plnpca.py +++ b/tests/test_plnpca.py @@ -1,158 +1,62 @@ import os import pytest -from pytest_lazyfixture import lazy_fixture as lf -from pyPLNmodels.models import PLNPCA, _PLNPCA -from pyPLNmodels import get_simulated_count_data, get_real_count_data -from tests.utils import MSE - import matplotlib.pyplot as plt import numpy as np -( - counts_sim, - covariates_sim, - offsets_sim, - true_covariance, - true_coef, -) = get_simulated_count_data(return_true_param=True) - -counts_real = get_real_count_data() -RANKS = [2, 8] - - -@pytest.fixture -def my_instance_plnpca(): - plnpca = PLNPCA(ranks=RANKS) - return plnpca - - -@pytest.fixture -def real_fitted_plnpca(my_instance_plnpca): - my_instance_plnpca.fit(counts_real) - return my_instance_plnpca - - -@pytest.fixture -def simulated_fitted_plnpca(my_instance_plnpca): - my_instance_plnpca.fit( - counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim - ) - return my_instance_plnpca - - -@pytest.fixture -def one_simulated_fitted_plnpca(): - model = PLNPCA(ranks=2) - model.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - return model - - -@pytest.fixture -def real_best_aic(real_fitted_plnpca): - return real_fitted_plnpca.best_model("AIC") - - -@pytest.fixture -def real_best_bic(real_fitted_plnpca): - return real_fitted_plnpca.best_model("BIC") - - -@pytest.fixture -def simulated_best_aic(simulated_fitted_plnpca): - return simulated_fitted_plnpca.best_model("AIC") - - -@pytest.fixture -def simulated_best_bic(simulated_fitted_plnpca): - return simulated_fitted_plnpca.best_model("BIC") +from tests.conftest import dict_fixtures +from tests.utils import MSE, filter_models -simulated_best_models = [lf("simulated_best_aic"), lf("simulated_best_bic")] -real_best_models = [lf("real_best_aic"), lf("real_best_bic")] -best_models = simulated_best_models + real_best_models - - -all_fitted_simulated_plnpca = [ - lf("simulated_fitted_plnpca"), - lf("one_simulated_fitted_plnpca"), -] -all_fitted_plnpca = [lf("real_fitted_plnpca")] + all_fitted_simulated_plnpca - - -def test_print_plnpca(simulated_fitted_plnpca): - print(simulated_fitted_plnpca) - - -@pytest.mark.parametrize("best_model", best_models) -def test_best_model(best_model): +@pytest.mark.parametrize("plnpca", 
dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) +def test_best_model(plnpca): + best_model = plnpca.best_model() print(best_model) -@pytest.mark.parametrize("best_model", best_models) -def test_projected_variables(best_model): +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) +def test_projected_variables(plnpca): + best_model = plnpca.best_model() plv = best_model.projected_latent_variables assert plv.shape[0] == best_model.n_samples and plv.shape[1] == best_model.rank -def test_save_load_back_and_refit(simulated_fitted_plnpca): - simulated_fitted_plnpca.save() - new = PLNPCA(ranks=RANKS) - new.load() - new.fit(counts=counts_sim, covariates=covariates_sim, offsets=offsets_sim) - - -@pytest.mark.parametrize("plnpca", all_fitted_simulated_plnpca) -def test_find_right_covariance(plnpca): - passed = True - for model in plnpca.models: - mse_covariance = MSE(model.covariance - true_covariance) - assert mse_covariance < 0.3 - - -@pytest.mark.parametrize("plnpca", all_fitted_simulated_plnpca) -def test_find_right_coef(plnpca): - for model in plnpca.models: - mse_coef = MSE(model.coef - true_coef) - assert mse_coef < 0.3 - - -@pytest.mark.parametrize("all_pca", all_fitted_plnpca) -def test_additional_methods_pca(all_pca): - all_pca.show() - all_pca.BIC - all_pca.AIC - all_pca.loglikes - - -@pytest.mark.parametrize("all_pca", all_fitted_plnpca) -def test_viz_pca(all_pca): - _, ax = plt.subplots() - all_pca[2].viz(ax=ax) - plt.show() - all_pca[2].viz() - plt.show() - n_samples = all_pca.n_samples - colors = np.random.randint(low=0, high=2, size=n_samples) - all_pca[2].viz(colors=colors) - plt.show() - - -@pytest.mark.parametrize( - "pca", [lf("real_fitted_plnpca"), lf("simulated_fitted_plnpca")] -) -def test_fails_viz_pca(pca): - with pytest.raises(RuntimeError): - pca[8].viz() - - -@pytest.mark.parametrize("all_pca", all_fitted_plnpca) -def test_closest(all_pca): +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) +def test_additional_methods_pca(plnpca): + plnpca.show() + plnpca.BIC + plnpca.AIC + plnpca.loglikes + + +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) +def test_viz_pca(plnpca): + models = plnpca.models + for model in models: + _, ax = plt.subplots() + model.viz(ax=ax) + plt.show() + model.viz() + plt.show() + n_samples = plnpca.n_samples + colors = np.random.randint(low=0, high=2, size=n_samples) + model.viz(colors=colors) + plt.show() + + +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) +def test_closest(plnpca): with pytest.warns(UserWarning): - all_pca[9] + plnpca[9] -@pytest.mark.parametrize("plnpca", all_fitted_plnpca) +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLNPCA"]) def test_wrong_criterion(plnpca): with pytest.raises(ValueError): plnpca.best_model("AIK") diff --git a/tests/utils.py b/tests/utils.py index 0cc7f2d7..2e33a3d0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,20 @@ import torch +import functools def MSE(t): return torch.mean(t**2) + + +def filter_models(models_name): + def decorator(my_test): + @functools.wraps(my_test) + def new_test(**kwargs): + fixture = next(iter(kwargs.values())) + if type(fixture).__name__ not in models_name: + return None + return my_test(**kwargs) + + return new_test + + return decorator -- GitLab From 5512fa9944261aac53bfff713b03657a3331a235 
Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Sat, 13 May 2023 08:59:02 +0200 Subject: [PATCH 22/24] remove useless tests. --- tests/test_args.py | 66 ------------------------------------------ tests/test_common.py | 2 ++ tests/test_pln_full.py | 5 ---- 3 files changed, 2 insertions(+), 71 deletions(-) delete mode 100644 tests/test_args.py diff --git a/tests/test_args.py b/tests/test_args.py deleted file mode 100644 index 2708851d..00000000 --- a/tests/test_args.py +++ /dev/null @@ -1,66 +0,0 @@ -import os - -from pyPLNmodels.models import PLN, PLNPCA, _PLNPCA -from pyPLNmodels import get_simulated_count_data, get_real_count_data -import pytest -from pytest_lazyfixture import lazy_fixture as lf -import pandas as pd -import numpy as np - -( - counts_sim, - covariates_sim, - offsets_sim, -) = get_simulated_count_data(nb_cov=2) - -couts_real = get_real_count_data(n_samples=298, dim=101) -RANKS = [2, 8] - - -@pytest.fixture -def instance_plnpca(): - plnpca = PLNPCA(ranks=RANKS) - return plnpca - - -@pytest.fixture -def instance__plnpca(): - model = _PLNPCA(rank=RANKS[0]) - return model - - -@pytest.fixture -def instance_pln_full(): - return PLN() - - -all_instances = [lf("instance_plnpca"), lf("instance__plnpca"), lf("instance_pln_full")] - - -@pytest.mark.parametrize("instance", all_instances) -def test_pandas_init(instance): - instance.fit( - pd.DataFrame(counts_sim.numpy()), - pd.DataFrame(covariates_sim.numpy()), - pd.DataFrame(offsets_sim.numpy()), - ) - - -@pytest.mark.parametrize("instance", all_instances) -def test_numpy_init(instance): - instance.fit(counts_sim.numpy(), covariates_sim.numpy(), offsets_sim.numpy()) - - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_counts(sim_pln): - sim_pln.fit() - - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_counts_and_offsets(sim_pln): - sim_pln.fit(counts=counts_sim, offsets=offsets_sim) - - -@pytest.mark.parametrize("sim_pln", simulated_any_pln) -def test_only_Y_and_cov(sim_pln): - sim_pln.fit(counts=counts_sim, covariates=covariates_sim) diff --git a/tests/test_common.py b/tests/test_common.py index 06f3ef8e..da5e57f0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -119,6 +119,7 @@ def test_fail_count_setter(pln): def test_setter_with_numpy(pln): np_counts = pln.counts.numpy() pln.counts = np_counts + pln.fit() @pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) @@ -126,6 +127,7 @@ def test_setter_with_numpy(pln): def test_setter_with_pandas(pln): pd_counts = pd.DataFrame(pln.counts.numpy()) pln.counts = pd_counts + pln.fit() @pytest.mark.parametrize("instance", dict_fixtures["instances"]) diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py index 078c5bb5..29291af7 100644 --- a/tests/test_pln_full.py +++ b/tests/test_pln_full.py @@ -2,8 +2,3 @@ import torch from import_fixtures_and_data import get_dict_fixtures from pyPLNmodels import PLN - - -df = get_dict_fixtures(PLN) -for key, fixture in df.items(): - print(len(fixture)) -- GitLab From e876ed5cfe00bb79591d354d966b010676f809d8 Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Sat, 13 May 2023 09:11:57 +0200 Subject: [PATCH 23/24] reorganized the tests file so that I test tests for only plnfull in test_pln_full and tests for only plnpca inside test_plnpca --- tests/test_common.py | 46 +----------------------------------------- tests/test_pln_full.py | 19 ++++++++++++++--- tests/test_plnpca.py | 14 +++++++++++++ tests/test_setters.py | 21 
+++++++++++++++++++ 4 files changed, 52 insertions(+), 48 deletions(-) create mode 100644 tests/test_setters.py diff --git a/tests/test_common.py b/tests/test_common.py index da5e57f0..43046096 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2,7 +2,6 @@ import os import torch import pytest -import pandas as pd from tests.conftest import dict_fixtures from tests.utils import MSE, filter_models @@ -92,44 +91,14 @@ def test_find_right_coef(simulated_fitted_any_pln): assert simulated_fitted_any_pln.coef is None -@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) -@filter_models(["PLN"]) -def test_number_of_iterations_pln_full(fitted_pln): - nb_iterations = len(fitted_pln.elbos_list) - assert 50 < nb_iterations < 300 - - -@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) -@filter_models(["_PLNPCA"]) -def test_number_of_iterations_plnpca(fitted_pln): - nb_iterations = len(fitted_pln.elbos_list) - assert 100 < nb_iterations < 5000 - - @pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["PLN"]) +@filter_models(["PLN", "_PLNPCA"]) def test_fail_count_setter(pln): wrong_counts = torch.randint(size=(10, 5), low=0, high=10) with pytest.raises(Exception): pln.counts = wrong_counts -@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) -@filter_models(["PLN", "PLNPCA"]) -def test_setter_with_numpy(pln): - np_counts = pln.counts.numpy() - pln.counts = np_counts - pln.fit() - - -@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) -@filter_models(["PLN", "PLNPCA"]) -def test_setter_with_pandas(pln): - pd_counts = pd.DataFrame(pln.counts.numpy()) - pln.counts = pd_counts - pln.fit() - - @pytest.mark.parametrize("instance", dict_fixtures["instances"]) def test_random_init(instance): instance.fit(do_smart_init=False) @@ -146,16 +115,3 @@ def test_fail_wrong_covariates_prediction(pln): X = torch.randn(pln.n_samples, pln.nb_cov + 1) with pytest.raises(Exception): pln.predict(X) - - -@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["_PLNPCA"]) -def test_latent_var_pca(plnpca): - assert plnpca.transform(project=False).shape == plnpca.counts.shape - assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank) - - -@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) -@filter_models(["PLN"]) -def test_latent_var_full(pln): - assert pln.transform().shape == pln.counts.shape diff --git a/tests/test_pln_full.py b/tests/test_pln_full.py index 29291af7..1f9d6a5d 100644 --- a/tests/test_pln_full.py +++ b/tests/test_pln_full.py @@ -1,4 +1,17 @@ -import torch +import pytest -from import_fixtures_and_data import get_dict_fixtures -from pyPLNmodels import PLN +from tests.conftest import dict_fixtures +from tests.utils import filter_models + + +@pytest.mark.parametrize("fitted_pln", dict_fixtures["fitted_pln"]) +@filter_models(["PLN"]) +def test_number_of_iterations_pln_full(fitted_pln): + nb_iterations = len(fitted_pln.elbos_list) + assert 50 < nb_iterations < 300 + + +@pytest.mark.parametrize("pln", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["PLN"]) +def test_latent_var_full(pln): + assert pln.transform().shape == pln.counts.shape diff --git a/tests/test_plnpca.py b/tests/test_plnpca.py index af06a2ad..9eb1b2f4 100644 --- a/tests/test_plnpca.py +++ b/tests/test_plnpca.py @@ -23,6 +23,20 @@ def test_projected_variables(plnpca): assert plv.shape[0] == best_model.n_samples and plv.shape[1] == best_model.rank +@pytest.mark.parametrize("fitted_pln", 
dict_fixtures["fitted_pln"]) +@filter_models(["_PLNPCA"]) +def test_number_of_iterations_plnpca(fitted_pln): + nb_iterations = len(fitted_pln.elbos_list) + assert 100 < nb_iterations < 5000 + + +@pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) +@filter_models(["_PLNPCA"]) +def test_latent_var_pca(plnpca): + assert plnpca.transform(project=False).shape == plnpca.counts.shape + assert plnpca.transform().shape == (plnpca.n_samples, plnpca.rank) + + @pytest.mark.parametrize("plnpca", dict_fixtures["loaded_and_fitted_pln"]) @filter_models(["PLNPCA"]) def test_additional_methods_pca(plnpca): diff --git a/tests/test_setters.py b/tests/test_setters.py new file mode 100644 index 00000000..b1a9ba29 --- /dev/null +++ b/tests/test_setters.py @@ -0,0 +1,21 @@ +import pytest +import pandas as pd + +from tests.conftest import dict_fixtures +from tests.utils import MSE, filter_models + + +@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) +@filter_models(["PLN", "PLNPCA"]) +def test_setter_with_numpy(pln): + np_counts = pln.counts.numpy() + pln.counts = np_counts + pln.fit() + + +@pytest.mark.parametrize("pln", dict_fixtures["all_pln"]) +@filter_models(["PLN", "PLNPCA"]) +def test_setter_with_pandas(pln): + pd_counts = pd.DataFrame(pln.counts.numpy()) + pln.counts = pd_counts + pln.fit() -- GitLab From ccd0a96807233fd9b9225f4adbf437e01212a51b Mon Sep 17 00:00:00 2001 From: bastien-mva <bastien.batardiere@gmail.com> Date: Sat, 13 May 2023 10:37:03 +0200 Subject: [PATCH 24/24] change version and change the gitlab ci to get token only for the project. --- .gitlab-ci.yml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a8dd3286..79e1610f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -26,7 +26,7 @@ publish_package: script: - pip install twine - python setup.py bdist_wheel - - TWINE_PASSWORD=${TWINE_PASSWORD} TWINE_USERNAME=${TWINE_USERNAME} python -m twine upload dist/* + - TWINE_PASSWORD=${pypln_token} TWINE_USERNAME=${account_name} python -m twine upload dist/* tags: - docker only: diff --git a/setup.py b/setup.py index 74cc8909..9c8b2745 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from setuptools import setup, find_packages -VERSION = "0.0.37" +VERSION = "0.0.38" with open("README.md", "r") as fh: long_description = fh.read() -- GitLab
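
The patches above converge on a single calling convention: the data (counts, covariates, offsets, or a formula plus a data object) is handed to the model constructor, and fit() only receives optimization arguments — this is what the rewritten fixtures and tests exercise. Below is a minimal usage sketch of that workflow, pieced together from the calls visible in the diffs; the function and keyword names simply mirror those appearing in the test files, and the defaults are assumptions about the package as it stands at the end of this series, not an official example from the repository.

from pyPLNmodels import get_simulated_count_data
from pyPLNmodels.models import PLN, PLNPCA

# Simulated data with two covariates, as in the (now removed) test_args.py.
counts, covariates, offsets = get_simulated_count_data(nb_cov=2)

# Full-covariance model: data at construction, optimization arguments at fit().
pln = PLN(counts, covariates=covariates, offsets=offsets)
pln.fit(tol=0.1, verbose=False)
print(pln.coef)  # a tensor here; None for a model built without covariates (patch 19)
assert pln.transform().shape == pln.counts.shape  # mirrors test_latent_var_full

# Collection of PCA models, one per rank; best_model() selects by an information criterion.
pca = PLNPCA(counts, covariates=covariates, offsets=offsets, ranks=[2, 8])
pca.fit(tol=0.1)
best = pca.best_model("BIC")
print(best.projected_latent_variables.shape)  # (n_samples, rank)

# Formula interface used by the *_formula fixtures (data is a dict-like object
# holding "counts" and optionally "covariates" and "offsets"):
# pln_formula = PLN("counts ~ 1", data)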