From 31ac638d4ede12ade89fd2cd3b71c118d054b1e4 Mon Sep 17 00:00:00 2001
From: Vincent Delbar <vincent.delbar@latelescop.fr>
Date: Mon, 25 Apr 2022 18:39:31 +0200
Subject: [PATCH 1/2] FIX: tfrecords dtype is forced to float32

---
 otbtf/utils.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/otbtf/utils.py b/otbtf/utils.py
index 729f7715..069638a5 100644
--- a/otbtf/utils.py
+++ b/otbtf/utils.py
@@ -38,12 +38,13 @@ def gdal_open(filename):
     return gdal_ds
 
 
-def read_as_np_arr(gdal_ds, as_patches=True):
+def read_as_np_arr(gdal_ds, as_patches=True, dtype=None):
     """
     Read a GDAL raster as numpy array
     :param gdal_ds: a GDAL dataset instance
     :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If
         False, the shape is (1, psz_y, psz_x, nb_channels)
+    :param dtype: if not None array dtype will be cast to given numpy data type (np.float32, np.uint16...)
     :return: Numpy array of dim 4
     """
     buffer = gdal_ds.ReadAsArray()
@@ -56,4 +57,9 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     else:
         n_elems = 1
         size_y = gdal_ds.RasterYSize
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+
+    buffer = buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
+    if dtype is not None:
+        buffer = buffer.astype(dtype)
+
+    return buffer
-- 
GitLab


From 60851101f4d2f80cad31df9bc160e81572d70f0d Mon Sep 17 00:00:00 2001
From: Remi Cresson <remi.cresson@irstea.fr>
Date: Mon, 9 May 2022 11:21:06 +0200
Subject: [PATCH 2/2] CI: test SR4RS fix

---
 .gitlab-ci.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 87964c83..4f302884 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -118,7 +118,7 @@ sr4rs:
     - wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/kDms9JrRMQE2Q5z/download
     - unzip -o sr4rs_data.zip
     - rm -rf sr4rs
-    - git clone https://github.com/remicres/sr4rs.git
+    - git clone -b 44-cast_float_input https://github.com/remicres/sr4rs.git
     - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs
     - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py
 
-- 
GitLab