From cc19877ed06e4a7dd471cf1277f2b0f07727327d Mon Sep 17 00:00:00 2001
From: Jean-Benoist Leger <jbleger@hds.utc.fr>
Date: Wed, 8 Feb 2023 18:14:08 +0100
Subject: [PATCH] black: the only color you need

---
 .gitlab-ci.yml               |  12 +++
 .pre-commit-config.yaml      |  10 ++
 CONTRIBUTING.md              |  15 +++
 docs/source/conf.py          |  38 +++----
 pyPLNmodels/VEM.py           | 195 ++++++++++++++++++++---------------
 pyPLNmodels/__init__.py      |   4 +-
 pyPLNmodels/_closed_forms.py |  13 +--
 pyPLNmodels/_utils.py        |  87 ++++++++--------
 pyPLNmodels/elbos.py         |  79 ++++++++------
 setup.py                     |  41 +++++---
 test.py                      |  16 +--
 11 files changed, 303 insertions(+), 207 deletions(-)
 create mode 100644 .gitlab-ci.yml
 create mode 100644 .pre-commit-config.yaml
 create mode 100644 CONTRIBUTING.md

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 00000000..9b122f8e
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,12 @@
+
+stages:
+  - checks
+  - publish
+
+black:
+  stage: checks
+  image: registry.gitlab.com/pipeline-components/black:latest
+  script:
+    - black --check --verbose -- .
+  tags:
+    - docker
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..fe4c8217
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,10 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.2.0
+    hooks:
+    -   id: trailing-whitespace
+    -   id: end-of-file-fixer
+  - repo: https://github.com/psf/black
+    rev: 22.3.0
+    hooks:
+      - id: black
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..faf13f3b
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,15 @@
+# Clone the repo
+
+```
+git clone git@forgemia.inra.fr:bbatardiere/pyplnmodels
+```
+
+# Install precommit
+
+In the directory:
+
+```
+pre-commit install
+```
+
+If not found use `pip install pre-commit` before this command.
diff --git a/docs/source/conf.py b/docs/source/conf.py
index bcebd83d..37dcd460 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,15 +12,16 @@
 #
 import os
 import sys
-sys.path.insert(0, os.path.abspath('../../pyPLNmodels'))
+
+sys.path.insert(0, os.path.abspath("../../pyPLNmodels"))
 # -- Project information -----------------------------------------------------
 
-project = 'pyPLNmodels'
-copyright = '2023, Bastien Batardière, Julien Chiquet, Joon Kwon'
-author = 'Bastien Batardière, Julien Chiquet, Joon Kwon'
+project = "pyPLNmodels"
+copyright = "2023, Bastien Batardière, Julien Chiquet, Joon Kwon"
+author = "Bastien Batardière, Julien Chiquet, Joon Kwon"
 
 # The full version, including alpha/beta/rc tags
-release = '0.0.15'
+release = "0.0.15"
 
 
 # -- General configuration ---------------------------------------------------
@@ -28,26 +29,27 @@ release = '0.0.15'
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.autodoc', 
-              'sphinx.ext.intersphinx',
-              'sphinx.ext.napoleon', 
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.napoleon",
 ]
-napoleon_google_docstring = False 
-napoleon_numpy_docstring = True 
+napoleon_google_docstring = False
+napoleon_numpy_docstring = True
 
 intersphinx_mappings = {
-            'python': ('https://docs.python.org/', None),
-            'numpy': ('http://docs.scipy.org/doc/numpy/', None)
-        }
+    "python": ("https://docs.python.org/", None),
+    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
+}
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-html_theme = 'sphinx_rtd_theme'
+templates_path = ["_templates"]
+html_theme = "sphinx_rtd_theme"
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
 #
 # This is also used if you do content translation via gettext catalogs.
 # Usually you set "language" from the command line for these cases.
-language = 'en'
+language = "en"
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
@@ -60,9 +62,9 @@ exclude_patterns = []
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'alabaster'
+html_theme = "alabaster"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
diff --git a/pyPLNmodels/VEM.py b/pyPLNmodels/VEM.py
index 7113836d..1c822df9 100644
--- a/pyPLNmodels/VEM.py
+++ b/pyPLNmodels/VEM.py
@@ -9,27 +9,28 @@ import matplotlib.pyplot as plt
 
 from ._closed_forms import closed_formula_beta, closed_formula_Sigma, closed_formula_pi
 from .elbos import ELBOPLN, ELBOPLNPCA, ELBOZIPLN, profiledELBOPLN
-from ._utils import PLNPlotArgs , init_Sigma, init_C, init_beta, getOFromSumOfY
+from ._utils import PLNPlotArgs, init_Sigma, init_C, init_beta, getOFromSumOfY
 
 if torch.cuda.is_available():
-    device = 'cuda'
+    device = "cuda"
 else:
-    device = 'cpu'
+    device = "cpu"
 
 # shoudl add a good init for M. for plnpca we should not put the maximum of the log posterior, for plnpca it may be ok.
 
 
 class _PLN(ABC):
-    """ 
-    Virtual class for all the PLN models. 
-    
+    """
+    Virtual class for all the PLN models.
+
     This class must be derivatived. The methods `get_Sigma`, `compute_ELBO`,
     `random_init_var_parameters` and `list_of_parameters_needing_gradient` must
-    be defined. 
+    be defined.
     """
+
     def __init__(self):
         """
-        Simple initialization method. 
+        Simple initialization method.
         """
         self.window = 3
         self.fitted = False
@@ -41,7 +42,7 @@ class _PLN(ABC):
         else:
             self.covariates = self.format_data(covariates)
         if O is None:
-            if O_formula == 'sum':
+            if O_formula == "sum":
                 self.O = torch.log(getOFromSumOfY(self.Y)).float()
             else:
                 self.O = torch.zeros(self.Y.shape)
@@ -67,19 +68,19 @@ class _PLN(ABC):
             return data
         else:
             raise Exception(
-                'Please insert either a numpy array, pandas.DataFrame or torch.tensor'
+                "Please insert either a numpy array, pandas.DataFrame or torch.tensor"
             )
 
-    def init_parameters(self, Y, covariates,O, doGoodInit):
+    def init_parameters(self, Y, covariates, O, doGoodInit):
         self.n, self.p = self.Y.shape
         self.d = self.covariates.shape[1]
-        print('Initialization ...')
+        print("Initialization ...")
         if doGoodInit:
             self.smart_init_model_parameters()
         else:
             self.random_init_model_parameters()
         self.random_init_var_parameters()
-        print('Initialization finished')
+        print("Initialization finished")
         self.putParametersToDevice()
 
     def putParametersToDevice(self):
@@ -90,38 +91,38 @@ class _PLN(ABC):
     def list_of_parameters_needing_gradient(self):
         pass
 
-
-    def fit(self,
-            Y,
-            covariates = None,
-            O = None,
-            nb_max_iteration=15000,
-            lr=0.01,
-            class_optimizer=torch.optim.Rprop,
-            tol=1e-3,
-            doGoodInit=True,
-            verbose=False,
-            O_formula = 'sum'):
-        """  
-        Main function of the class. Fit a PLN to the data.  
+    def fit(
+        self,
+        Y,
+        covariates=None,
+        O=None,
+        nb_max_iteration=15000,
+        lr=0.01,
+        class_optimizer=torch.optim.Rprop,
+        tol=1e-3,
+        doGoodInit=True,
+        verbose=False,
+        O_formula="sum",
+    ):
+        """
+        Main function of the class. Fit a PLN to the data.
         Parameters
         ----------
         Y : torch.tensor or ndarray or DataFrame.
-            2-d count data. 
+            2-d count data.
         covariates : torch.tensor or ndarray or DataFrame or None, default = None
-            If not `None`, the first dimension should equal the first dimension of `Y`. 
+            If not `None`, the first dimension should equal the first dimension of `Y`.
         O : torch.tensor or ndarray or DataFrame or None, default = None
-            Model offset. If not `None`, size should be the same as `Y`. 
+            Model offset. If not `None`, size should be the same as `Y`.
         """
         self.t0 = time.time()
         if self.fitted == False:
             self.plotargs = PLNPlotArgs(self.window)
-            self.format_datas(Y,covariates,O, O_formula)
-            self.init_parameters(Y, covariates,O, doGoodInit)
-        else: 
+            self.format_datas(Y, covariates, O, O_formula)
+            self.init_parameters(Y, covariates, O, doGoodInit)
+        else:
             self.t0 -= self.plotargs.running_times[-1]
-        self.optim = class_optimizer(
-            self.list_of_parameters_needing_gradient, lr=lr)
+        self.optim = class_optimizer(self.list_of_parameters_needing_gradient, lr=lr)
         nb_iteration_done = 0
         stop_condition = False
         while nb_iteration_done < nb_max_iteration and stop_condition == False:
@@ -145,25 +146,33 @@ class _PLN(ABC):
 
     def print_end_of_fitting_message(self, stop_condition, tol):
         if stop_condition:
-            print('Tolerance {} reached in {} iterations'.format(
-                tol, self.plotargs.iteration_number))
+            print(
+                "Tolerance {} reached in {} iterations".format(
+                    tol, self.plotargs.iteration_number
+                )
+            )
         else:
-            print('Maximum number of iterations reached : ',
-                  self.plotargs.iteration_number, 'last criterion = ',
-                  np.round(self.plotargs.criterions[-1], 8))
+            print(
+                "Maximum number of iterations reached : ",
+                self.plotargs.iteration_number,
+                "last criterion = ",
+                np.round(self.plotargs.criterions[-1], 8),
+            )
 
     def print_stats(self):
-        print('-------UPDATE-------')
-        print('Iteration number: ', self.plotargs.iteration_number)
-        print('Criterion: ', np.round(self.plotargs.criterions[-1], 8))
-        print('ELBO:', np.round(self.plotargs.ELBOs_list[-1], 6))
+        print("-------UPDATE-------")
+        print("Iteration number: ", self.plotargs.iteration_number)
+        print("Criterion: ", np.round(self.plotargs.criterions[-1], 8))
+        print("ELBO:", np.round(self.plotargs.ELBOs_list[-1], 6))
 
     def compute_criterion_and_update_plotargs(self, loss, tol):
         self.plotargs.ELBOs_list.append(-loss.item() / self.n)
         self.plotargs.running_times.append(time.time() - self.t0)
         if self.plotargs.iteration_number > self.window:
-            criterion = abs(self.plotargs.ELBOs_list[-1] -
-                        self.plotargs.ELBOs_list[-1 - self.window])
+            criterion = abs(
+                self.plotargs.ELBOs_list[-1]
+                - self.plotargs.ELBOs_list[-1 - self.window]
+            )
             self.plotargs.criterions.append(criterion)
             return criterion
         else:
@@ -176,21 +185,21 @@ class _PLN(ABC):
     def compute_ELBO(self):
         pass
 
-    def display_Sigma(self, ax=None, savefig=False, name_file=''):
+    def display_Sigma(self, ax=None, savefig=False, name_file=""):
         """
-        Display a heatmap of Sigma to visualize correlations. 
+        Display a heatmap of Sigma to visualize correlations.
 
-        If Sigma is too big (size is > 400), will only display the first block 
-        of size (400,400). 
-        Parameters 
+        If Sigma is too big (size is > 400), will only display the first block
+        of size (400,400).
+        Parameters
         ---------
-        
+
         ax : matplotlib Axes, optional
             Axes in which to draw the plot, otherwise use the currently-active Axes.
-        savefig: bool, optional 
-            If True the figure will be saved. Default is False. 
+        savefig: bool, optional
+            If True the figure will be saved. Default is False.
         name_file : str, optional
-            The name of the file the graphic will be saved to if saved. 
+            The name of the file the graphic will be saved to if saved.
             Default is an empty string.
         """
         fig = plt.figure()
@@ -203,30 +212,32 @@ class _PLN(ABC):
         plt.close()  # to avoid displaying a blanck screen
 
     def __str__(self):
-        string ='A multivariate Poisson Lognormal with ' + self.DESCRIPTION + '\n' 
-        string += 'Best likelihood:'+ str(np.max(-self.plotargs.ELBOs_list[-1])) + '\n'
+        string = "A multivariate Poisson Lognormal with " + self.DESCRIPTION + "\n"
+        string += "Best likelihood:" + str(np.max(-self.plotargs.ELBOs_list[-1])) + "\n"
         return string
 
     def show(self):
-        print('Best likelihood:',np.max(-self.plotargs.ELBOs_list[-1]))
+        print("Best likelihood:", np.max(-self.plotargs.ELBOs_list[-1]))
         fig, axes = plt.subplots(1, 3, figsize=(23, 5))
         self.plotargs.show_loss(ax=axes[0])
         self.plotargs.show_stopping_criterion(ax=axes[1])
         self.display_Sigma(ax=axes[2])
         plt.show()
-        return ''
+        return ""
+
     @abstractmethod
     def get_Sigma(self):
         pass
-    
+
     @property
     def ELBOs_list(self):
         return self.plotargs.ELBOs_list
 
 
 class PLN(_PLN):
-    NAME = 'PLN'
-    DESCRIPTION= 'full covariance model.'
+    NAME = "PLN"
+    DESCRIPTION = "full covariance model."
+
     def random_init_var_parameters(self):
         self.S = 1 / 2 * torch.ones((self.n, self.p)).to(device)
         self.M = torch.ones((self.n, self.p)).to(device)
@@ -239,8 +250,14 @@ class PLN(_PLN):
         return profiledELBOPLN(self.Y, self.covariates, self.O, self.M, self.S)
 
     def get_Sigma(self):
-        return closed_formula_Sigma(self.covariates, self.M, self.S, self.get_beta(),
-                                 self.n).detach().cpu()
+        return (
+            closed_formula_Sigma(
+                self.covariates, self.M, self.S, self.get_beta(), self.n
+            )
+            .detach()
+            .cpu()
+        )
+
     def smart_init_model_parameters(self):
         pass
 
@@ -250,16 +267,17 @@ class PLN(_PLN):
     def get_beta(self):
         return closed_formula_beta(self.covariates, self.M).detach().cpu()
 
-    @property 
-    def beta(self): 
+    @property
+    def beta(self):
         return self.get_beta()
 
-    @property 
+    @property
     def Sigma(self):
         return self.get_Sigma()
 
+
 class PLNPCA(_PLN):
-    NAME = 'PLNPCA'
+    NAME = "PLNPCA"
     DESCRIPTION = " with Principal Component Analysis."
 
     def __init__(self, q):
@@ -283,16 +301,17 @@ class PLNPCA(_PLN):
         return [self.C, self.beta, self.M, self.S]
 
     def compute_ELBO(self):
-        return ELBOPLNPCA(self.Y, self.covariates, self.O, self.M, self.S, self.C,
-                       self.beta)
+        return ELBOPLNPCA(
+            self.Y, self.covariates, self.O, self.M, self.S, self.C, self.beta
+        )
 
     def get_Sigma(self):
         return (self.C @ (self.C.T)).detach().cpu()
 
-    
+
 class ZIPLN(PLN):
-    NAME = 'ZIPLN'
-    DESCRIPTION= 'with full covariance model and zero-inflation.'
+    NAME = "ZIPLN"
+    DESCRIPTION = "with full covariance model and zero-inflation."
 
     def random_init_model_parameters(self):
         super().random_init_model_parameters()
@@ -306,16 +325,24 @@ class ZIPLN(PLN):
         self.Theta_zero = torch.randn(self.d, self.p)
 
     def random_init_var_parameters(self):
-        self.dirac = (self.Y == 0)
+        self.dirac = self.Y == 0
         self.M = torch.randn(self.n, self.p)
         self.S = torch.randn(self.n, self.p)
-        self.pi = torch.empty(self.n, self.p).uniform_(
-            0, 1).to(device) * self.dirac
+        self.pi = torch.empty(self.n, self.p).uniform_(0, 1).to(device) * self.dirac
 
     def compute_ELBO(self):
-        return ELBOZIPLN(self.Y, self.covariates, self.O, self.M, self.S,
-                      self.pi, self.Sigma, self.beta, self.Theta_zero,
-                      self.dirac)
+        return ELBOZIPLN(
+            self.Y,
+            self.covariates,
+            self.O,
+            self.M,
+            self.S,
+            self.pi,
+            self.Sigma,
+            self.beta,
+            self.Theta_zero,
+            self.dirac,
+        )
 
     def get_Sigma(self):
         return self.Sigma.detach().cpu()
@@ -326,7 +353,9 @@ class ZIPLN(PLN):
 
     def update_closed_forms(self):
         self.beta = closed_formula_beta(self.covariates, self.M)
-        self.Sigma = closed_formula_Sigma(self.covariates, self.M, self.S, self.beta,
-                                 self.n)
-        self.pi = closed_formula_pi(self.O, self.M, self.S, self.dirac, self.covariates,
-                           self.Theta_zero)
+        self.Sigma = closed_formula_Sigma(
+            self.covariates, self.M, self.S, self.beta, self.n
+        )
+        self.pi = closed_formula_pi(
+            self.O, self.M, self.S, self.dirac, self.covariates, self.Theta_zero
+        )
diff --git a/pyPLNmodels/__init__.py b/pyPLNmodels/__init__.py
index 8e9c99cb..ecc50736 100644
--- a/pyPLNmodels/__init__.py
+++ b/pyPLNmodels/__init__.py
@@ -1,4 +1,4 @@
-__version__ = '0.0.15'
+__version__ = "0.0.15"
 
-from .VEM import (PLNPCA, PLN)
+from .VEM import PLNPCA, PLN
 from .elbos import profiledELBOPLN, ELBOPLNPCA, ELBOPLN
diff --git a/pyPLNmodels/_closed_forms.py b/pyPLNmodels/_closed_forms.py
index 3e6d1ce7..d9d7dd9c 100644
--- a/pyPLNmodels/_closed_forms.py
+++ b/pyPLNmodels/_closed_forms.py
@@ -12,13 +12,10 @@ def closed_formula_Sigma(covariates, M, S, beta, n):
 def closed_formula_beta(covariates, M):
     """Closed form for beta for the M step for the noPCA model."""
     return torch.mm(
-        torch.mm(
-            torch.inverse(torch.mm(
-                covariates.T,
-                covariates)),
-            covariates.T),
-        M)
+        torch.mm(torch.inverse(torch.mm(covariates.T, covariates)), covariates.T), M
+    )
+
 
 def closed_formula_pi(O, M, S, dirac, covariates, Theta_zero):
-    A = torch.exp(O+M+torch.multiply(S, S)/2)
-    return  torch.multiply(torch.sigmoid(A+torch.mm(covariates, Theta_zero)), dirac)
+    A = torch.exp(O + M + torch.multiply(S, S) / 2)
+    return torch.multiply(torch.sigmoid(A + torch.mm(covariates, Theta_zero)), dirac)
diff --git a/pyPLNmodels/_utils.py b/pyPLNmodels/_utils.py
index e56ba35f..90ca5dd7 100644
--- a/pyPLNmodels/_utils.py
+++ b/pyPLNmodels/_utils.py
@@ -23,7 +23,6 @@ class PLNPlotArgs:
         self.criterions = [1] * window
         self.ELBOs_list = list()
 
-
     @property
     def iteration_number(self):
         return len(self.ELBOs_list)
@@ -45,8 +44,9 @@ class PLNPlotArgs:
             -np.array(self.ELBOs_list),
             label="Negative ELBO",
         )
-        ax.set_title("Negative ELBO. Best ELBO = " +
-                     str(np.round(self.ELBOs_list[-1], 6)))
+        ax.set_title(
+            "Negative ELBO. Best ELBO = " + str(np.round(self.ELBOs_list[-1], 6))
+        )
         ax.set_yscale("log")
         ax.set_xlabel("Seconds")
         ax.set_ylabel("ELBO")
@@ -67,9 +67,11 @@ class PLNPlotArgs:
         """
         if ax is None:
             ax = plt.gca()
-        ax.plot(self.running_times[self.window:],
-                self.criterions[self.window:],
-                label="Delta")
+        ax.plot(
+            self.running_times[self.window :],
+            self.criterions[self.window :],
+            label="Delta",
+        )
         ax.set_yscale("log")
         ax.set_xlabel("Seconds")
         ax.set_ylabel("Delta")
@@ -81,14 +83,7 @@ class PLNPlotArgs:
 
 
 class PoissonRegressor:
-    def fit(self,
-            Y,
-            covariates,
-            O,
-            Niter_max=300,
-            tol=0.001,
-            lr=0.005,
-            verbose=False):
+    def fit(self, Y, covariates, O, Niter_max=300, tol=0.001, lr=0.005, verbose=False):
         """Run a gradient ascent to maximize the log likelihood, using
         pytorch autodifferentiation. The log likelihood considered is
         the one from a poisson regression model. It is roughly the
@@ -111,14 +106,14 @@ class PoissonRegressor:
                 by calling self.beta.
         """
         # Initialization of beta of size (d,p)
-        beta = torch.rand((covariates.shape[1], Y.shape[1]),
-                          device=device,
-                          requires_grad=True)
+        beta = torch.rand(
+            (covariates.shape[1], Y.shape[1]), device=device, requires_grad=True
+        )
         optimizer = torch.optim.Rprop([beta], lr=lr)
         i = 0
         gradNorm = 2 * tol  # Criterion
         while i < Niter_max and gradNorm > tol:
-            loss = -poissreg_loglike(Y,  covariates,O, beta)
+            loss = -poissreg_loglike(Y, covariates, O, beta)
             loss.backward()
             optimizer.step()
             gradNorm = torch.norm(beta.grad)
@@ -136,8 +131,8 @@ class PoissonRegressor:
         self.beta = beta
 
 
-def init_Sigma(Y,  covariates,O, beta):
-    """ Initialization for Sigma for the PLN model. Take the log of Y
+def init_Sigma(Y, covariates, O, beta):
+    """Initialization for Sigma for the PLN model. Take the log of Y
     (careful when Y=0), remove the covariates effects X@beta and
     then do as a MLE for Gaussians samples.
     Args :
@@ -151,8 +146,9 @@ def init_Sigma(Y,  covariates,O, beta):
     # then we set the log(Y) as 0.
     log_Y = torch.log(Y + (Y == 0) * math.exp(-2))
     # we remove the mean so that we see only the covariances
-    log_Y_centered = log_Y - \
-        torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
+    log_Y_centered = (
+        log_Y - torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
+    )
     # MLE in a Gaussian setting
     n = Y.shape[0]
     Sigma_hat = 1 / (n - 1) * (log_Y_centered.T) @ log_Y_centered
@@ -160,7 +156,7 @@ def init_Sigma(Y,  covariates,O, beta):
     return Sigma_hat
 
 
-def init_C(Y,  covariates,O, beta, q):
+def init_C(Y, covariates, O, beta, q):
     """Inititalization for C for the PLN model. Get a first
     guess for Sigma that is easier to estimate and then takes
     the q largest eigenvectors to get C.
@@ -174,13 +170,13 @@ def init_C(Y,  covariates,O, beta, q):
         torch.tensor of size (p,q). The initialization of C.
     """
     # get a guess for Sigma
-    Sigma_hat = init_Sigma(Y,  covariates,O, beta).detach()
+    Sigma_hat = init_Sigma(Y, covariates, O, beta).detach()
     # taking the q largest eigenvectors
     C = C_from_Sigma(Sigma_hat, q)
     return C
 
 
-def init_M(Y, covariates,O, beta, C, N_iter_max, lr, eps=7e-3):
+def init_M(Y, covariates, O, beta, C, N_iter_max, lr, eps=7e-3):
     """Initialization for the variational parameter M. Basically,
     the mode of the log_posterior is computed.
 
@@ -200,12 +196,12 @@ def init_M(Y, covariates,O, beta, C, N_iter_max, lr, eps=7e-3):
     W = torch.randn(Y.shape[0], C.shape[1], device=device)
     W.requires_grad_(True)
     optimizer = torch.optim.Rprop([W], lr=lr)
-    crit= 2 * eps
+    crit = 2 * eps
     old_W = torch.clone(W)
     keep_condition = True
     i = 0
     while i < N_iter_max and keep_condition:
-        loss = -torch.mean(log_PW_given_Y(Y, covariates,O, W, C, beta))
+        loss = -torch.mean(log_PW_given_Y(Y, covariates, O, W, C, beta))
         loss.backward()
         optimizer.step()
         crit = torch.max(torch.abs(W - old_W))
@@ -218,7 +214,7 @@ def init_M(Y, covariates,O, beta, C, N_iter_max, lr, eps=7e-3):
     return W
 
 
-def poissreg_loglike(Y, covariates,O, beta):
+def poissreg_loglike(Y, covariates, O, beta):
     """Compute the log likelihood of a Poisson regression."""
     # Matrix multiplication of X and beta.
     XB = torch.matmul(covariates.unsqueeze(1), beta.unsqueeze(0)).squeeze()
@@ -231,7 +227,7 @@ def sigmoid(x):
     return 1 / (1 + torch.exp(-x))
 
 
-def sample_PLN(C, beta, covariates,O, B_zero=None):
+def sample_PLN(C, beta, covariates, O, B_zero=None):
     """Sample Poisson log Normal variables. If B_zero is not None, the model will
     be zero inflated.
 
@@ -282,17 +278,19 @@ def build_block_Sigma(p, block_size):
     # np.random.seed(0)
     k = p // block_size  # number of matrices of size p//block_size.
     # will multiply each block by some random quantities
-    alea = np.random.randn(k + 1)**2 + 1
+    alea = np.random.randn(k + 1) ** 2 + 1
     Sigma = np.zeros((p, p))
     last_block_size = p - k * block_size
     # We need to form the k matrics of size p//block_size
     for i in range(k):
-        Sigma[i * block_size:(i + 1) * block_size, i * block_size:(i + 1) *
-              block_size] = alea[i] * toeplitz(0.7**np.arange(block_size))
+        Sigma[
+            i * block_size : (i + 1) * block_size, i * block_size : (i + 1) * block_size
+        ] = alea[i] * toeplitz(0.7 ** np.arange(block_size))
     # Last block matrix.
     if last_block_size > 0:
         Sigma[-last_block_size:, -last_block_size:] = alea[k] * toeplitz(
-            0.7**np.arange(last_block_size))
+            0.7 ** np.arange(last_block_size)
+        )
     return Sigma
 
 
@@ -314,7 +312,7 @@ def C_from_Sigma(Sigma, q):
 
 def init_beta(Y, covariates, O):
     poiss_reg = PoissonRegressor()
-    poiss_reg.fit(Y, covariates,O)
+    poiss_reg.fit(Y, covariates, O)
     return torch.clone(poiss_reg.beta.detach()).to(device)
 
 
@@ -326,13 +324,13 @@ def log_stirling(n):
     Returns:
         An approximation of log(n_!) element-wise.
     """
-    n_ = n + \
-        (n == 0)  # Replace the 0 with 1. It doesn't change anything since 0! = 1!
-    return torch.log(torch.sqrt(
-        2 * np.pi * n_)) + n_ * torch.log(n_ / math.exp(1))  # Stirling formula
+    n_ = n + (n == 0)  # Replace the 0 with 1. It doesn't change anything since 0! = 1!
+    return torch.log(torch.sqrt(2 * np.pi * n_)) + n_ * torch.log(
+        n_ / math.exp(1)
+    )  # Stirling formula
 
 
-def log_PW_given_Y(Y_b, covariates_b,O_b, W, C, beta):
+def log_PW_given_Y(Y_b, covariates_b, O_b, W, C, beta):
     """Compute the log posterior of the PLN model. Compute it either
     for W of size (N_samples, N_batch,q) or (batch_size, q). Need to have
     both cases since it is done for both cases after. Please the mathematical
@@ -347,14 +345,11 @@ def log_PW_given_Y(Y_b, covariates_b,O_b, W, C, beta):
     if length == 2:
         CW = torch.matmul(C.unsqueeze(0), W.unsqueeze(2)).squeeze()
     elif length == 3:
-        CW = torch.matmul(C.unsqueeze(0).unsqueeze(1),
-                          W.unsqueeze(3)).squeeze()
+        CW = torch.matmul(C.unsqueeze(0).unsqueeze(1), W.unsqueeze(3)).squeeze()
 
     A_b = O_b + CW + covariates_b @ beta
-    first_term = -q / 2 * math.log(2 * math.pi) - \
-        1 / 2 * torch.norm(W, dim=-1) ** 2
-    second_term = torch.sum(-torch.exp(A_b) + A_b * Y_b - log_stirling(Y_b),
-                            axis=-1)
+    first_term = -q / 2 * math.log(2 * math.pi) - 1 / 2 * torch.norm(W, dim=-1) ** 2
+    second_term = torch.sum(-torch.exp(A_b) + A_b * Y_b - log_stirling(Y_b), axis=-1)
     return first_term + second_term
 
 
@@ -370,5 +365,5 @@ def trunc_log(x, eps=1e-16):
 
 
 def getOFromSumOfY(Y):
-    sumOfY = torch.sum(Y, axis = 1)
+    sumOfY = torch.sum(Y, axis=1)
     return sumOfY.repeat((Y.shape[1], 1)).T
diff --git a/pyPLNmodels/elbos.py b/pyPLNmodels/elbos.py
index a135051d..cc6ea1c0 100644
--- a/pyPLNmodels/elbos.py
+++ b/pyPLNmodels/elbos.py
@@ -3,7 +3,7 @@ from ._utils import log_stirling, trunc_log
 from ._closed_forms import closed_formula_Sigma, closed_formula_beta
 
 
-def ELBOPLN(Y, covariates,O, M, S, Sigma, beta):
+def ELBOPLN(Y, covariates, O, M, S, Sigma, beta):
     """
     Compute the ELBO (Evidence LOwer Bound) for the PLN model. See the doc for more details
     on the computation.
@@ -25,19 +25,23 @@ def ELBOPLN(Y, covariates,O, M, S, Sigma, beta):
     MmoinsXB = M - torch.mm(covariates, beta)
     elbo = -n / 2 * torch.logdet(Sigma)
     elbo += torch.sum(
-        torch.multiply(Y, OplusM) - torch.exp(OplusM + SrondS / 2) +
-        1 / 2 * torch.log(SrondS))
+        torch.multiply(Y, OplusM)
+        - torch.exp(OplusM + SrondS / 2)
+        + 1 / 2 * torch.log(SrondS)
+    )
     DplusMmoinsXB2 = torch.diag(torch.sum(SrondS, dim=0)) + torch.mm(
-        MmoinsXB.T, MmoinsXB)
+        MmoinsXB.T, MmoinsXB
+    )
     moinspsur2n = 1 / 2 * torch.trace(torch.mm(torch.inverse(Sigma), DplusMmoinsXB2))
     elbo -= 1 / 2 * torch.trace(torch.mm(torch.inverse(Sigma), DplusMmoinsXB2))
     elbo -= torch.sum(log_stirling(Y))
     elbo += n * p / 2
     return elbo
 
-def profiledELBOPLN(Y,covariates,O,M,S): 
+
+def profiledELBOPLN(Y, covariates, O, M, S):
     """
-    Compute the ELBO (Evidence LOwer Bound) for the PLN model. We use the fact that Sigma and beta are 
+    Compute the ELBO (Evidence LOwer Bound) for the PLN model. We use the fact that Sigma and beta are
     completely determined by M,S, and the covariates. See the doc for more details
     on the computation.
 
@@ -55,17 +59,19 @@ def profiledELBOPLN(Y,covariates,O,M,S):
     n, p = Y.shape
     SrondS = torch.multiply(S, S)
     OplusM = O + M
-    closed_beta = closed_formula_beta(covariates, M) 
+    closed_beta = closed_formula_beta(covariates, M)
     closed_Sigma = closed_formula_Sigma(covariates, M, S, closed_beta, n)
-    elbo = -n/2*torch.logdet(closed_Sigma)
+    elbo = -n / 2 * torch.logdet(closed_Sigma)
     elbo += torch.sum(
-        torch.multiply(Y, OplusM) - torch.exp(OplusM + SrondS / 2) +
-        1 / 2 * torch.log(SrondS))
+        torch.multiply(Y, OplusM)
+        - torch.exp(OplusM + SrondS / 2)
+        + 1 / 2 * torch.log(SrondS)
+    )
     elbo -= torch.sum(log_stirling(Y))
     return elbo
 
 
-def ELBOPLNPCA(Y, covariates,O, M, S, C, beta):
+def ELBOPLNPCA(Y, covariates, O, M, S, C, beta):
     """
     Compute the ELBO (Evidence LOwer Bound) for the PLN model with a PCA
     parametrization. See the doc for more details on the computation.
@@ -87,16 +93,23 @@ def ELBOPLNPCA(Y, covariates,O, M, S, C, beta):
     SrondS = torch.multiply(S, S)
     YA = torch.sum(torch.multiply(Y, A))
     moinsexpAplusSrondSCCT = torch.sum(
-        -torch.exp(A + 1 / 2 * torch.mm(SrondS,
-                                        torch.multiply(C, C).T)))
+        -torch.exp(A + 1 / 2 * torch.mm(SrondS, torch.multiply(C, C).T))
+    )
     moinslogSrondS = 1 / 2 * torch.sum(torch.log(SrondS))
-    MMplusSrondS = torch.sum(-1 / 2 *
-                             (torch.multiply(M, M) + torch.multiply(S, S)))
+    MMplusSrondS = torch.sum(-1 / 2 * (torch.multiply(M, M) + torch.multiply(S, S)))
     log_stirlingY = torch.sum(log_stirling(Y))
-    return YA + moinsexpAplusSrondSCCT + moinslogSrondS + MMplusSrondS - log_stirlingY + n * q / 2
+    return (
+        YA
+        + moinsexpAplusSrondSCCT
+        + moinslogSrondS
+        + MMplusSrondS
+        - log_stirlingY
+        + n * q / 2
+    )
+
 
 ## should rename some variables so that is is clearer when we see the formula
-def ELBOZIPLN(Y, covariates,O, M, S, pi, Sigma, beta, B_zero, dirac):
+def ELBOZIPLN(Y, covariates, O, M, S, pi, Sigma, beta, B_zero, dirac):
     """Compute the ELBO (Evidence LOwer Bound) for the Zero Inflated PLN model.
     See the doc for more details on the computation.
 
@@ -114,7 +127,7 @@ def ELBOZIPLN(Y, covariates,O, M, S, pi, Sigma, beta, B_zero, dirac):
         torch.tensor of size 1 with a gradient.
     """
     if torch.norm(pi * dirac - pi) > 0.0001:
-        print('Bug')
+        print("Bug")
         return False
     n = Y.shape[0]
     p = Y.shape[1]
@@ -125,20 +138,28 @@ def ELBOZIPLN(Y, covariates,O, M, S, pi, Sigma, beta, B_zero, dirac):
     elbo = torch.sum(
         torch.multiply(
             1 - pi,
-            torch.multiply(Y, OplusM) - torch.exp(OplusM + SrondS / 2) -
-            log_stirling(Y)) + pi)
+            torch.multiply(Y, OplusM)
+            - torch.exp(OplusM + SrondS / 2)
+            - log_stirling(Y),
+        )
+        + pi
+    )
 
     elbo -= torch.sum(
-        torch.multiply(pi, trunc_log(pi)) +
-        torch.multiply(1 - pi, trunc_log(1 - pi)))
-    elbo += torch.sum(
-        torch.multiply(pi, XB_zero) - torch.log(1 + torch.exp(XB_zero)))
+        torch.multiply(pi, trunc_log(pi)) + torch.multiply(1 - pi, trunc_log(1 - pi))
+    )
+    elbo += torch.sum(torch.multiply(pi, XB_zero) - torch.log(1 + torch.exp(XB_zero)))
 
-    elbo -= 1 / 2 * torch.trace(
-        torch.mm(
-            torch.inverse(Sigma),
-            torch.diag(torch.sum(SrondS, dim=0)) +
-            torch.mm(MmoinsXB.T, MmoinsXB)))
+    elbo -= (
+        1
+        / 2
+        * torch.trace(
+            torch.mm(
+                torch.inverse(Sigma),
+                torch.diag(torch.sum(SrondS, dim=0)) + torch.mm(MmoinsXB.T, MmoinsXB),
+            )
+        )
+    )
     elbo += n / 2 * torch.log(torch.det(Sigma))
     elbo += n * p / 2
     elbo += torch.sum(1 / 2 * torch.log(SrondS))
diff --git a/setup.py b/setup.py
index 86781ef8..92a7c604 100644
--- a/setup.py
+++ b/setup.py
@@ -8,25 +8,37 @@ with open("requirements.txt", "r") as fh:
     requirements = [line.strip() for line in fh]
 
 setup(
-    name='pyPLNmodels',
+    name="pyPLNmodels",
     version=__version__,
-    description = 'Package implementing PLN models',
-    url=,
+    description="Package implementing PLN models",
     project_urls={
-        "Source": 'https://github.com/PLN-team/PLNpy/tree/master/pyPLNmodels', 
-        }
-    author='Bastien Batardière, Julien Chiquet, Joon Kwon',
-    author_email='bastien.batardiere@gmail.com, julien.chiquet@inrae.fr, joon.kwon@inrae.fr',
-    license_files = ('LICENSE.txt',),
+        "Source": "https://github.com/PLN-team/PLNpy/tree/master/pyPLNmodels",
+    },
+    author="Bastien Batardière, Julien Chiquet, Joon Kwon",
+    author_email="bastien.batardiere@gmail.com, julien.chiquet@inrae.fr, joon.kwon@inrae.fr",
+    license_files=("LICENSE.txt",),
     long_description=long_description,
     packages=find_packages(),
-    python_requires='>=3',
-    keywords=['python','count', 'data', 'count data', 'high dimension', 'scRNAseq', 'PLN'],
+    python_requires=">=3",
+    keywords=[
+        "python",
+        "count",
+        "data",
+        "count data",
+        "high dimension",
+        "scRNAseq",
+        "PLN",
+    ],
     install_requires=requirements,
-    py_modules=['pyPLNmodels._utils','pyPLNmodels.elbos','pyPLNmodels.VEM','pyPLNmodels._closed_forms'],
-    long_description_content_type='text/markdown',
+    py_modules=[
+        "pyPLNmodels._utils",
+        "pyPLNmodels.elbos",
+        "pyPLNmodels.VEM",
+        "pyPLNmodels._closed_forms",
+    ],
+    long_description_content_type="text/markdown",
     license="MIT",
-         # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
+    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
     classifiers=[
         # How mature is this project? Common values are
         #   3 - Alpha
@@ -40,6 +52,5 @@ setup(
         # Specify the Python versions you support here. In particular, ensure
         # that you indicate whether you support Python 2, Python 3 or both.
         "Programming Language :: Python :: 3 :: Only",
-        ],
-
+    ],
 )
diff --git a/test.py b/test.py
index 542455fc..d1e3f985 100644
--- a/test.py
+++ b/test.py
@@ -3,17 +3,21 @@ import torch
 from pyPLNmodels.VEM import ZIPLN, PLN, PLNPCA
 import numpy as np
 import seaborn as sns
-import matplotlib.pyplot as plt 
+import matplotlib.pyplot as plt
 
 Y = pd.read_csv("./example_data/test_data/Y_test.csv")
 covariates = pd.read_csv("./example_data/test_data/cov_test.csv")
-O = (pd.read_csv("./example_data/test_data/O_test.csv"))
-true_Sigma = torch.from_numpy(pd.read_csv("./example_data/test_data/true_parameters/true_Sigma_test.csv").values)
-true_beta = torch.from_numpy(pd.read_csv("./example_data/test_data/true_parameters/true_beta_test.csv").values)
+O = pd.read_csv("./example_data/test_data/O_test.csv")
+true_Sigma = torch.from_numpy(
+    pd.read_csv("./example_data/test_data/true_parameters/true_Sigma_test.csv").values
+)
+true_beta = torch.from_numpy(
+    pd.read_csv("./example_data/test_data/true_parameters/true_beta_test.csv").values
+)
 
 pln = PLN()
-pln.fit(Y, covariates, O,nb_max_iteration= 20)
+pln.fit(Y, covariates, O, nb_max_iteration=20)
 print(pln)
-pca = PLNPCA(q = 4)
+pca = PLNPCA(q=4)
 pca.fit(Y, covariates, O)
 print(pca)
-- 
GitLab