
Adapt TFRecords data type to the initial data type

Merged: Cresson Remi requested to merge 21-tfrecords-data-type into modifs
All threads resolved!
1 file changed  +27 −12
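What this change does: patch buffers were previously always cast to np.float32 before being written to TFRecords; the code now looks up the GDAL band DataType of the source raster and keeps the matching numpy dtype. A minimal sketch of the idea, using the GDAL_TO_NP_TYPES mapping introduced in the diff below (the read_native helper is illustrative only, not part of the patch):

from osgeo import gdal

# Mapping added by this MR: GDAL DataType codes -> numpy dtype names
GDAL_TO_NP_TYPES = {1: 'uint8', 2: 'uint16', 3: 'int16', 4: 'uint32', 5: 'int32',
                    6: 'float32', 7: 'float64', 10: 'complex64', 11: 'complex128'}

def read_native(path):
    """Read a raster as a numpy array in its native data type (illustrative helper)."""
    gdal_ds = gdal.Open(path)
    gdal_type = gdal_ds.GetRasterBand(1).DataType  # e.g. 1 (GDT_Byte) for 8-bit imagery
    return gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type])  # no float32 coercion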
@@ -34,6 +34,19 @@ import tensorflow as tf
 from osgeo import gdal
 from tqdm import tqdm
+# --------------------------------------------- GDAL to numpy types ----------------------------------------------------
+GDAL_TO_NP_TYPES = {1: 'uint8',
+                    2: 'uint16',
+                    3: 'int16',
+                    4: 'uint32',
+                    5: 'int32',
+                    6: 'float32',
+                    7: 'float64',
+                    10: 'complex64',
+                    11: 'complex128'}
 # ----------------------------------------------------- Helpers --------------------------------------------------------
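Side note on the mapping above: the GDAL Python bindings also provide a type-code lookup helper, so the explicit dict is mostly a readability choice. A quick check, assuming a GDAL build that ships osgeo.gdal_array:

from osgeo import gdal, gdal_array

# GDALTypeCodeToNumericTypeCode maps a GDAL type code to a numpy type,
# e.g. gdal.GDT_Byte (code 1) -> numpy.uint8, matching GDAL_TO_NP_TYPES[1].
print(gdal_array.GDALTypeCodeToNumericTypeCode(gdal.GDT_Byte))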
@@ -58,8 +71,9 @@ def read_as_np_arr(gdal_ds, as_patches=True):
         False, the shape is (1, psz_y, psz_x, nb_channels)
     :return: Numpy array of dim 4
     """
-    buffer = gdal_ds.ReadAsArray()
+    gdal_type = gdal_ds.GetRasterBand(1).DataType
     size_x = gdal_ds.RasterXSize
+    buffer = gdal_ds.ReadAsArray().astype(GDAL_TO_NP_TYPES[gdal_type])
     if len(buffer.shape) == 3:
         buffer = np.transpose(buffer, axes=(1, 2, 0))
     if not as_patches:
@@ -68,7 +82,7 @@ def read_as_np_arr(gdal_ds, as_patches=True):
     else:
         n_elems = int(gdal_ds.RasterYSize / size_x)
         size_y = size_x
-    return np.float32(buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount)))
+    return buffer.reshape((n_elems, size_y, size_x, gdal_ds.RasterCount))
 # -------------------------------------------------- Buffer class ------------------------------------------------------
@@ -246,6 +260,7 @@ class PatchesImagesReader(PatchesReaderBase):
     def _read_extract_as_np_arr(gdal_ds, offset):
         assert gdal_ds is not None
         psz = gdal_ds.RasterXSize
+        gdal_type = gdal_ds.GetRasterBand(1).DataType
         yoff = int(offset * psz)
         assert yoff + psz <= gdal_ds.RasterYSize
         buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
@@ -254,7 +269,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else: # single-band raster
             buffer = np.expand_dims(buffer, axis=2)
-        return np.float32(buffer)
+        return buffer.astype(GDAL_TO_NP_TYPES[gdal_type])
     def get_sample(self, index):
         """
@@ -613,8 +628,8 @@ class TFRecords:
         """
         data_converted = {}
-        for k, d in data.items():
-            data_converted[k] = d.name
+        for key, value in data.items():
+            data_converted[key] = value.name
         return data_converted
@@ -629,7 +644,7 @@ class TFRecords:
            filepath = os.path.join(self.dirpath, f"{i}.records")
            with tf.io.TFRecordWriter(filepath) as writer:
-                for s in range(nb_sample):
+                for _ in range(nb_sample):
                    sample = dataset.read_one_sample()
                    serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
                    features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
                                serialized_sample.items()}
@@ -646,8 +661,8 @@ class TFRecords:
         :param filepath: Output file name
         """
-        with open(filepath, 'w') as f:
-            json.dump(data, f, indent=4)
+        with open(filepath, 'w') as file:
+            json.dump(data, file, indent=4)
     @staticmethod
     def load(filepath):
@@ -655,8 +670,8 @@ class TFRecords:
         Return data from pickle format.
         :param filepath: Input file name
         """
-        with open(filepath, 'r') as f:
-            return json.load(f)
+        with open(filepath, 'r') as file:
+            return json.load(file)
     def convert_dataset_output_shapes(self, dataset):
         """
@@ -665,8 +680,8 @@ class TFRecords:
         """
         output_shapes = {}
-        for key in dataset.output_shapes.keys():
-            output_shapes[key] = (None,) + dataset.output_shapes[key]
+        for key, value in dataset.output_shapes.keys():
+            output_shapes[key] = (None,) + value
         self.save(output_shapes, self.output_shape_file)
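In the last hunk, the rewritten loop unpacks key and value but still iterates dataset.output_shapes.keys(), which yields keys only; presumably .items() is intended. A small sketch of the shape-prepending logic under that assumption, with a stand-in dict instead of a real dataset:

# Stand-in for dataset.output_shapes (hypothetical feature names and shapes, for illustration only)
dataset_output_shapes = {"src_patches": (16, 16, 4), "labels": (16, 16, 1)}

# Prepend a batch dimension (None) to every per-sample shape
output_shapes = {key: (None,) + tuple(value) for key, value in dataset_output_shapes.items()}
print(output_shapes)  # {'src_patches': (None, 16, 16, 4), 'labels': (None, 16, 16, 1)}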