diff --git a/common/column_store.py b/common/column_store.py index cb72076eea..14d7accc3d 100644 --- a/common/column_store.py +++ b/common/column_store.py @@ -158,7 +158,7 @@ class ColumnStoreWriter(): # TODO(mgraczyk): This implementation will need to change if we add zip or compression. return ColumnStoreWriter(os.path.join(self._path, group_name)) - def add_dict(self, data, dtype=None, compression=False, overwrite=False): + def add_dict(self, data, dtypes=None, compression=False, overwrite=False): # default name exists to have backward compatibility with equivalent directory structure npy_path = os.path.join(self._path, "columnstore") mkdirs_exists_ok(os.path.dirname(npy_path)) @@ -166,6 +166,7 @@ class ColumnStoreWriter(): flat_dict = dict() _flatten_dict(flat_dict, "", data) for k, v in flat_dict.items(): + dtype = dtypes[k] if dtypes is not None and k in dtypes else None flat_dict[k] = np.array(v, copy=False, dtype=dtype) if overwrite: