From 31ac638d4ede12ade89fd2cd3b71c118d054b1e4 Mon Sep 17 00:00:00 2001 From: Vincent Delbar <vincent.delbar@latelescop.fr> Date: Mon, 25 Apr 2022 18:39:31 +0200 Subject: [PATCH 1/2] FIX: tfrecords dtype is forced to float32 --- otbtf/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/otbtf/utils.py b/otbtf/utils.py index 729f7715..069638a5 100644 --- a/otbtf/utils.py +++ b/otbtf/utils.py @@ -38,12 +38,13 @@ def gdal_open(filename): return gdal_ds -def read_as_np_arr(gdal_ds, as_patches=True): +def read_as_np_arr(gdal_ds, as_patches=True, dtype=None): """ Read a GDAL raster as numpy array :param gdal_ds: a GDAL dataset instance :param as_patches: if True, the returned numpy array has the following shape (n, psz_x, psz_x, nb_channels). If False, the shape is (1, psz_y, psz_x, nb_channels) + :param dtype: if not None array dtype will be cast to given numpy data type (np.float32, np.uint16...) :return: Numpy array of dim 4 """ buffer = gdal_ds.ReadAsArray() @@ -56,4 +57,9 @@ def read_as_np_arr(gdal_ds, as_patches=True): else: n_elems = 1 size_y = gdal_ds.RasterYSize - return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))) + + buffer = buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)) + if dtype is not None: + buffer = buffer.astype(dtype) + + return buffer -- GitLab From 60851101f4d2f80cad31df9bc160e81572d70f0d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@irstea.fr> Date: Mon, 9 May 2022 11:21:06 +0200 Subject: [PATCH 2/2] CI: test SR4RS fix --- .gitlab-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 87964c83..4f302884 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -118,7 +118,7 @@ sr4rs: - wget -O sr4rs_data.zip https://nextcloud.inrae.fr/s/kDms9JrRMQE2Q5z/download - unzip -o sr4rs_data.zip - rm -rf sr4rs - - git clone https://github.com/remicres/sr4rs.git + - git clone -b 44-cast_float_input https://github.com/remicres/sr4rs.git - export PYTHONPATH=$PYTHONPATH:$PWD/sr4rs - python -m pytest --junitxml=$ARTIFACT_TEST_DIR/report_sr4rs.xml $OTBTF_SRC/test/sr4rs_unittest.py -- GitLab