import numpy as np
import dask
import dask.array as dsa
from dask.base import tokenize, normalize_token
import xarray as xr
import warnings
from .duck_array_ops import concatenate
from .shrunk_index import all_index_data
from ..variables import dimensions
def _get_grid_metadata():
    """Build the metadata dictionary for grid variables.

    Kept separate from _get_var_metadata because grid variables need
    special handling: most are keyed by their on-disk filename rather
    than by their variable name.
    """
    from ..mds_store import _get_all_grid_variables
    from ..variables import extra_grid_variables, vertical_coordinates

    # gather all known grid variables for the llc geometry
    grid_vars = _get_all_grid_variables('llc')
    grid_vars.update(extra_grid_variables)

    # build a dictionary keyed by filename where one exists
    grid_metadata = {}
    for name, meta in grid_vars.items():
        # masks use the hFac filename and are computed in mds_store,
        # so they keep their own name as key
        if 'filename' in meta and name[:4] != 'mask':
            meta.update({'real_name': name})
            grid_metadata[meta['filename']] = meta
        else:
            grid_metadata[name] = meta

    # force RF to point to Zp1, deal with this manually..
    grid_metadata['RF'] = vertical_coordinates['Zp1']
    grid_metadata['RF']['real_name'] = 'Zp1'
    for zname in ('Zu', 'Zl'):
        grid_metadata[zname] = vertical_coordinates[zname]
    return grid_metadata
def _get_var_metadata():
    """Assemble metadata for all known variables.

    The LLC run data comes with zero metadata, so everything is pulled
    from the xmitgcm package's own variable tables and diagnostics list.
    """
    from ..variables import state_variables, package_state_variables
    from ..utils import parse_available_diagnostics
    from ..default_diagnostics import diagnostics
    from io import StringIO

    available_diags = parse_available_diagnostics(StringIO(diagnostics))

    # merge the three metadata sources; later tables win on conflicts
    var_metadata = {}
    for table in (state_variables, package_state_variables, available_diags):
        var_metadata.update(table)

    # even the file names from the LLC data differ from standard MITgcm output
    aliases = {'Eta': 'ETAN', 'PhiBot': 'PHIBOT', 'Salt': 'SALT',
               'Theta': 'THETA'}
    for llc_name, std_name in aliases.items():
        var_metadata[llc_name] = var_metadata[std_name]

    # add grid metadata
    var_metadata.update(_get_grid_metadata())
    return var_metadata
# module-level metadata table, built once at import time; maps variable
# names (and on-disk filenames) to their dims/attrs dictionaries
_VAR_METADATA = _get_var_metadata()
def _is_vgrid(vname):
    """Return True for 1-d vertical grid variables (single 'k*' dimension)."""
    dims = _VAR_METADATA[vname]['dims']
    if len(dims) != 1:
        return False
    return dims[0][0] == 'k'
def _get_variable_point(vname, mask_override):
# fix for https://github.com/MITgcm/xmitgcm/issues/191
if vname in mask_override:
return mask_override[vname]
dims = _VAR_METADATA[vname]['dims']
if 'i' in dims and 'j' in dims:
point = 'c'
elif 'i_g' in dims and 'j' in dims:
point = 'w'
elif 'i' in dims and 'j_g' in dims:
point = 's'
elif 'i_g' in dims and 'j_g' in dims:
raise ValueError("Don't have masks for corner points!")
else:
raise ValueError("Variable `%s` is not a horizontal variable." % vname)
return point
def _get_scalars_and_vectors(varnames, type):
for vname in varnames:
if vname not in _VAR_METADATA:
raise ValueError("Varname `%s` not found in metadata." % vname)
if type != 'latlon':
return varnames, []
scalars = []
vector_pairs = []
for vname in varnames:
meta = _VAR_METADATA[vname]
try:
mate = meta['attrs']['mate']
if mate not in varnames:
raise ValueError("Vector pairs are required to create "
"latlon type datasets. Varname `%s` is "
"missing its vector mate `%s`"
% vname, mate)
vector_pairs.append((vname, mate))
varnames.remove(mate)
except KeyError:
scalars.append(vname)
def _decompress(data, mask, dtype):
data_blank = np.full_like(mask, np.nan, dtype=dtype)
data_blank[mask] = data
data_blank.shape = mask.shape
return data_blank
[docs]def _pad_facet(data,facet_shape,reshape,pad_before,pad_after,dtype):
"""add padding to facet data that are irregularly shaped, returning
data of size ``facet_shape`` to make equally sized faces.
See
https://xmitgcm.readthedocs.io/en/latest/llcreader.html#aste-release-1-example
for an example
Parameters
----------
data : array like
data to be reshaped
facet_shape : tuple
"expected" facet shape determined by the _facet_shape function
reshape : bool
whether to reshape each face, True if facet is "rotated".
This determines the axis along which pad_before or pad_after refers to
pad_before, pad_after : int
size of padding added to either the i or j dimension where
before vs after determines ordering of: data-then-pad or pad-then-data
dtype : numpy.dtype
Datatype of the data
Returns
-------
padded_data : array like
which has shape = facet_shape, with nan's for padding
"""
pre_shape=list(facet_shape)
pad_shape=list(facet_shape)
if reshape:
concat_axis = -1
pad_shape[concat_axis] = pad_after
pre_shape[concat_axis] -= pad_after
padded= [data,np.full(pad_shape,np.nan,dtype=dtype)]
else:
concat_axis = -2
pad_shape[concat_axis] = pad_before
pre_shape[concat_axis] -= pad_before
padded = [np.full(pad_shape,np.nan,dtype=dtype),data]
data.shape = pre_shape
return concatenate(padded,axis=concat_axis)
def _facet_strides(nfaces):
if nfaces == 13:
return ((0,3), (3,6), (6,7), (7,10), (10,13))
elif nfaces == 6:
return ((0,2), (2,2), (2,3), (3,4), (4,6))
else:
raise TypeError(f'Unexpected nfaces {nfaces} for _facet_strides')
# whether to reshape each face
# (the last two facets of the LLC layout are rotated, so their faces are
# concatenated along the x axis rather than the y axis)
_facet_reshape = (False, False, False, True, True)
# the LLC decomposition always has five facets
_nfacets = 5
def _uncompressed_facet_index(nfacet, nside, nfaces):
    """Element start/end offsets of one facet within an uncompressed level."""
    face_size = nside ** 2
    lo, hi = _facet_strides(nfaces)[nfacet]
    return lo * face_size, hi * face_size
def _facet_shape(nfacet, nside, nfaces):
    """Shape (1, ny, nx) of one facet; rotated facets are wide, not tall."""
    lo, hi = _facet_strides(nfaces)[nfacet]
    facet_length = hi - lo
    if _facet_reshape[nfacet]:
        return (1, nside, facet_length * nside)
    return (1, facet_length * nside, nside)
def _facet_to_faces(data, nfacet, nfaces):
    """Split one facet into its individual faces.

    Parameters
    ----------
    data : array like
        facet data; the last three axes are (facet=1, ny, nx)
    nfacet : int
        which of the five facets this is
    nfaces : int
        total number of faces in the domain (13 or 6)

    Returns
    -------
    Array with the singleton facet axis replaced by a faces axis of
    length ``facet_length``, or None for an empty facet (possible in
    6-face domains).
    """
    shape = data.shape
    # facet dimension
    nf, ny, nx = shape[-3:]
    other_dims = shape[:-3]
    assert nf == 1
    strides = _facet_strides(nfaces)[nfacet]
    facet_length = strides[1] - strides[0]
    if facet_length == 0:
        # empty facet: nothing to split (previously only the non-reshape
        # branch guarded against this)
        return None
    if _facet_reshape[nfacet]:
        # rotated facet: faces are stacked along x
        # fix: use integer division -- float division produced a float in
        # the shape tuple, which reshape rejects
        new_shape = other_dims + (ny, facet_length, nx // facet_length)
        data_rs = data.reshape(new_shape)
        data_rs = np.moveaxis(data_rs, -2, -3)  # dask-safe
    else:
        new_shape = other_dims + (facet_length, ny // facet_length, nx)
        data_rs = data.reshape(new_shape)
    return data_rs
def _facets_to_faces(facets, nfaces):
    """Concatenate a list of five facets into a single faces array."""
    pieces = []
    for nfacet, facet_data in enumerate(facets):
        faces = _facet_to_faces(facet_data, nfacet, nfaces)
        # empty facets (6-face domains) contribute nothing
        if faces is not None:
            pieces.append(faces)
    return concatenate(pieces, axis=-3)
def _faces_to_facets(data, nfaces, facedim=-3):
    """Split a faces array into the five LLC facets."""
    assert data.shape[facedim] == nfaces
    facets = []
    for nfacet, (strides, reshape) in enumerate(zip(_facet_strides(nfaces),
                                                    _facet_reshape)):
        # slice out each face, keeping a singleton face axis
        face_data = [data[(...,) + (slice(nface, nface + 1), slice(None), slice(None))]
                     for nface in range(*strides)]
        # rotated facets concatenate along x, upright ones along y
        concat_axis = facedim + 2 if reshape else facedim + 1
        # todo: use duck typing for concat
        if face_data:
            facet_data = concatenate(face_data, axis=concat_axis)
        else:
            # empty facet (possible in 6-face domains)
            facet_data = np.array([])
        facets.append(facet_data)
    return facets
def _rotate_scalar_facet(facet):
facet_transposed = np.moveaxis(facet, -1, -2)
facet_rotated = np.flip(facet_transposed, -2)
return facet_rotated
def _facets_to_latlon_scalar(all_facets):
    """Assemble scalar facets into one latlon array (Arctic facet dropped)."""
    # the last two facets are rotated; bring them into latlon orientation
    oriented = list(all_facets[:2])
    oriented += [_rotate_scalar_facet(facet) for facet in all_facets[-2:]]
    # drop facet dimension, then stitch facets together along x
    squeezed = [r[..., 0, :, :] for r in oriented]
    return concatenate(squeezed, axis=-1)
def _faces_to_latlon_scalar(data, nfaces):
    """Convert a faces scalar array directly to the latlon layout."""
    return _facets_to_latlon_scalar(_faces_to_facets(data, nfaces))
# dask's pad function doesn't work
# it does weird things to non-pad dimensions
# need to roll our own
def shift_and_pad(a, left=True):
    """Shift ``a`` by one element along the last axis, zero-padding the
    open end (left=True shifts toward the start, padding at the end)."""
    if left:
        shifted = a[..., 1:]
        pad = dsa.zeros_like(a[..., -2:-1])
        pieces = [shifted, pad]
    else:
        shifted = a[..., :-1]
        pad = dsa.zeros_like(a[..., 0:1])
        pieces = [pad, shifted]
    return concatenate(pieces, axis=-1)
def transform_v_to_u(facet):
    """Rotate a v-component facet so it becomes a u component."""
    return _rotate_scalar_facet(facet)
def transform_u_to_v(facet, metric=False):
    """Rotate a u-component facet so it becomes a v component.

    The u values are first "shifted" by one pixel along the last axis
    (via shift_and_pad, since dask's pad misbehaves -- see the note
    above shift_and_pad), then rotated.  Unless the field is a
    positive-definite metric, the sign is flipped.
    """
    # fix: removed unused `pad_width` local and the commented-out
    # dsa.pad call it supported
    facet_padded = shift_and_pad(facet)
    assert facet.shape == facet_padded.shape
    facet_rotated = _rotate_scalar_facet(facet_padded)
    if not metric:
        facet_rotated = -facet_rotated
    return facet_rotated
def _facets_to_latlon_vector(facets_u, facets_v, metric=False):
    """Assemble u/v facet pairs into latlon u and v arrays (Arctic dropped)."""
    # need at least (facet, y, x) axes; the rotated v values get padded
    assert facets_u[0].ndim >= 3
    # drop facet dimension
    u_parts = [f[..., 0, :, :] for f in facets_u]
    v_parts = [f[..., 0, :, :] for f in facets_v]
    # rotated facets swap component roles: v feeds u and u feeds v
    u_rot = u_parts[:2] + [transform_v_to_u(facet) for facet in v_parts[-2:]]
    v_rot = v_parts[:2] + [transform_u_to_v(facet, metric) for facet in u_parts[-2:]]
    u = concatenate(u_rot, axis=-1)
    v = concatenate(v_rot, axis=-1)
    return u, v
def _faces_to_latlon_vector(u_faces, v_faces, nfaces, metric=False):
    """Convert faces u/v arrays directly to the latlon layout."""
    return _facets_to_latlon_vector(_faces_to_facets(u_faces, nfaces),
                                    _faces_to_facets(v_faces, nfaces),
                                    metric=metric)
def _drop_facedim(dims):
dims = list(dims)
dims.remove('face')
return dims
def _add_face_to_dims(dims):
new_dims = dims.copy()
if 'j' in dims:
j_dim = dims.index('j')
new_dims.insert(j_dim, 'face')
elif 'j_g' in dims:
j_dim = dims.index('j_g')
new_dims.insert(j_dim, 'face')
return new_dims
def _faces_coords_to_latlon(ds):
    """Build the latlon dimension coordinates for a faces dataset.

    The latlon grid is 4 faces wide (i dimensions) and 3 faces tall
    (j dimensions); all other coordinates are carried over unchanged.
    """
    coords = ds.reset_coords().coords.to_dataset()
    ifac, jfac = 4, 3
    dim_coords = {}
    for vname in coords.coords:
        if vname[0] == 'i':
            # covers both 'i' and 'i_g'
            data = np.arange(ifac * coords.dims[vname])
        elif vname[0] == 'j':
            # covers both 'j' and 'j_g'
            data = np.arange(jfac * coords.dims[vname])
        else:
            data = coords[vname].data
        dim_coords[vname] = xr.Variable(ds[vname].dims, data, ds[vname].attrs)
    return xr.Dataset(dim_coords)
def faces_dataset_to_latlon(ds, metric_vector_pairs=[('dxC', 'dyC'), ('dyG', 'dxG')]):
    """Transform a 13-face LLC xarray Dataset into a rectancular grid,
    discarding the Arctic.

    Parameters
    ----------
    ds : xarray.Dataset
        A 13-face LLC dataset
    metric_vector_pairs : list, optional
        Pairs of variables that are positive-definite metrics located at grid
        edges.

    Returns
    -------
    out : xarray.Dataset
        Transformed rectangular dataset
    """
    # fix: removed stray "[docs]" sphinx-scrape artifact that made the
    # def line a syntax error
    coord_vars = list(ds.coords)
    ds_new = _faces_coords_to_latlon(ds)

    # identify vector pairs via the 'mate' attribute; everything else is
    # treated as a scalar
    vector_pairs = []
    scalars = []
    vnames = list(ds.reset_coords().variables)
    for vname in vnames:
        try:
            mate = ds[vname].attrs['mate']
        except KeyError:
            mate = None
        # Raises an exception if the mate of a variable in vnames is missing.
        if mate is not None:
            vector_pairs.append((vname, mate))
            try:
                vnames.remove(mate)
            except ValueError:
                msg = 'If {} in varnames, {} must also be in varnames'.format(vname, mate)
                raise ValueError(msg)

    all_vector_components = [inner for outer in (vector_pairs + metric_vector_pairs)
                             for inner in outer]
    scalars = [vname for vname in vnames if vname not in all_vector_components]
    data_vars = {}

    # scalars: rotate facets and stitch together, dropping the face dim
    for vname in scalars:
        if vname == 'face' or vname in ds_new:
            continue
        if 'face' in ds[vname].dims:
            data = _faces_to_latlon_scalar(ds[vname].data, nfaces=len(ds['face']))
            dims = _drop_facedim(ds[vname].dims)
        else:
            data = ds[vname].data
            dims = ds[vname].dims
        data_vars[vname] = xr.Variable(dims, data, ds[vname].attrs)

    # vector pairs: u and v components must be rotated together
    for vname_u, vname_v in vector_pairs:
        data_u, data_v = _faces_to_latlon_vector(ds[vname_u].data, ds[vname_v].data,
                                                 nfaces=len(ds['face']))
        data_vars[vname_u] = xr.Variable(_drop_facedim(ds[vname_u].dims), data_u, ds[vname_u].attrs)
        data_vars[vname_v] = xr.Variable(_drop_facedim(ds[vname_v].dims), data_v, ds[vname_v].attrs)

    # metric pairs are positive definite, so no sign flip on rotation
    for vname_u, vname_v in metric_vector_pairs:
        data_u, data_v = _faces_to_latlon_vector(ds[vname_u].data, ds[vname_v].data, nfaces=len(ds['face']), metric=True)
        data_vars[vname_u] = xr.Variable(_drop_facedim(ds[vname_u].dims), data_u, ds[vname_u].attrs)
        data_vars[vname_v] = xr.Variable(_drop_facedim(ds[vname_v].dims), data_v, ds[vname_v].attrs)

    ds_new = ds_new.update(data_vars)
    ds_new = ds_new.set_coords([c for c in coord_vars if c in ds_new])
    return ds_new
# below are data transformers
def _all_facets_to_faces(data_facets, meta, nfaces):
    """Apply the facets-to-faces transform to every variable in the dict.
    ``meta`` is unused; kept for signature parity with _all_facets_to_latlon."""
    out = {}
    for vname, facets in data_facets.items():
        out[vname] = _facets_to_faces(facets, nfaces)
    return out
def _all_facets_to_latlon(data_facets, meta, nfaces=None):
    """Apply the facets-to-latlon transform to every variable in the dict.

    Vector pairs, identified via the 'mate' entry of a variable's attrs
    in ``meta``, are rotated together; remaining variables are scalars.
    ``nfaces`` is unused; kept for signature parity with
    _all_facets_to_faces.
    """
    vector_pairs = []
    vnames = list(data_facets)
    # mates are removed from vnames while iterating, so each pair is
    # only visited once
    for vname in vnames:
        try:
            mate = meta[vname]['attrs']['mate']
        except KeyError:
            mate = None
        # Raises an exception if the mate of a variable in vnames is missing.
        if mate is None:
            continue
        vector_pairs.append((vname, mate))
        try:
            vnames.remove(mate)
        except ValueError:
            msg = 'If {} in varnames, {} must also be in varnames'.format(vname, mate)
            raise ValueError(msg)

    paired_names = [name for pair in vector_pairs for name in pair]
    data = {}
    for vname in vnames:
        if vname in paired_names:
            continue
        data[vname] = _facets_to_latlon_scalar(data_facets[vname])
    for vname_u, vname_v in vector_pairs:
        data_u, data_v = _facets_to_latlon_vector(data_facets[vname_u],
                                                  data_facets[vname_v])
        data[vname_u] = data_u
        data[vname_v] = data_v
    return data
def _chunks(l, n):
"""Yield successive n-sized chunks from l."""
for i in range(0, len(l), n):
yield l[i:i + n]
def _get_facet_chunk(store, varname, iternum, nfacet, klevels, nx, nz, nfaces,
                     dtype, mask_override, domain, pad_before, pad_after):
    """Read one facet of ``varname`` (at the requested k levels) from the store.

    Returns a numpy array of shape (len(klevels), 1, ny, nx); a leading
    singleton time axis is added when ``iternum`` is not None.
    """
    fs, path = store.get_fs_and_full_path(varname, iternum)

    assert (nfacet >= 0) & (nfacet < _nfacets)

    file = fs.open(path)

    # insert singleton axis for time (if not grid var) and k level
    facet_shape = (1,) + _facet_shape(nfacet, nx, nfaces)
    facet_shape = (1,) + facet_shape if iternum is not None else facet_shape

    level_data = []

    if (store.shrunk and iternum is not None) or \
       (store.shrunk_grid and iternum is None):
        # the store tells us whether we need a mask or not
        point = _get_variable_point(varname, mask_override)
        mykey = nx if domain == 'global' else f'{domain}_{nx}'
        index = all_index_data[mykey][point]
        zgroup = store.open_mask_group()
        mask = zgroup['mask_' + point].astype('bool')
    else:
        index = None
        mask = None

    # Offset start/end read position due to padding facet before me
    pre_pad = np.cumsum([x + y for x, y in zip(pad_before, pad_after)])
    pre_pad = shift_and_pad(pre_pad, left=False).compute()
    tot_pad = pre_pad[-1] + pad_after[-1]

    for k in klevels:
        assert (k >= 0) & (k < nz)

        # figure out where in the file we have to read to get the data
        # for this level and facet
        # fix: compare against None explicitly -- truthiness of an
        # array-like index/mask is ambiguous and can raise
        if index is not None:
            i = np.ravel_multi_index((k, nfacet), (nz, _nfacets))
            start = index[i]
            end = index[i + 1]
        else:
            level_start = k * (nx**2 * nfaces - nx * tot_pad)
            facet_start, facet_end = _uncompressed_facet_index(nfacet, nx, nfaces)
            start = level_start + facet_start
            end = level_start + facet_end - nx * pad_after[nfacet]
            end = end - nx * (pad_before[nfacet]) if k * nfacet == 0 else end
            start, end = [x - nx * pre_pad[nfacet] if k + nfacet != 0 else x for x in [start, end]]

        read_offset = start * dtype.itemsize  # in bytes
        read_length = (end - start) * dtype.itemsize  # in bytes
        file.seek(read_offset)
        buffer = file.read(read_length)
        data = np.frombuffer(buffer, dtype=dtype)
        assert len(data) == (end - start)

        if mask is not None:
            # shrunk store: scatter the wet points back onto the full facet
            mask_level = mask[k]
            mask_facets = _faces_to_facets(mask_level, nfaces)
            this_mask = mask_facets[nfacet]
            data = _decompress(data, this_mask, dtype)
        elif pad_before[nfacet] + pad_after[nfacet] > 0:
            # Extra care for pad after with rotated fields
            data = _pad_facet(data, facet_shape, _facet_reshape[nfacet],
                              pad_before[nfacet], pad_after[nfacet], dtype)

        # this is the shape this facet is supposed to have
        data.shape = facet_shape
        level_data.append(data)

    out = np.concatenate(level_data, axis=-4)
    return out
def _get_1d_chunk(store, varname, klevels, nz, dtype):
"""for 1D vertical grid variables"""
fs, path = store.get_fs_and_full_path(varname, None)
file = fs.open(path)
# read all levels for 1D variables
read_length = nz*dtype.itemsize # all levels in bytes
buffer = file.read(read_length)
data = np.frombuffer(buffer,dtype=dtype)
return data[klevels]
class BaseLLCModel:
    """Class representing an LLC Model Dataset.

    Parameters
    ----------
    store : llcreader.BaseStore
        The store object where the data can be found
    mask_ds : zarr.Group
        Must contain variables `mask_c`, `mask_w`, `mask_s`

    Attributes
    ----------
    dtype : numpy.dtype
        Datatype of the data in the dataset
    nx : int
        Number of gridpoints per face (e.g. 90, 1080, 4320, etc.)
    nz : int
        Number of vertical gridpoints
    delta_t : float
        Numerical timestep
    time_units : str
        Date unit string, e.g 'seconds since 1948-01-01 12:00:00'
    iter_start : int
        First model iteration number (inclusive; follows python range conventions)
    iter_stop : int
        Final model iteration number (exclusive; follows python range conventions)
    iter_step : int
        Spacing between iterations
    iters : list of ints
        Specific iteration numbers in a list, possibly with nonuniform spacing.
        Either provide this or the iter parameters above.
    varnames, grid_varnames : list
        List of data variable and grid variable names contained in the dataset
    mask_override : dict
        Override inference of masking variable, e.g. ``{'oceTAUX': 'c'}``
    """
    # NOTE: removed the stray "[docs]" sphinx-scrape artifacts that made
    # the class and get_dataset lines syntax errors

    nface = 13
    dtype = np.dtype('>f4')

    # should be implemented by child classes
    nx = None
    nz = None
    delta_t = None
    time_units = None
    iter_start = None
    iter_stop = None
    iter_step = None
    iters = None
    varnames = []
    grid_varnames = []
    mask_override = {}
    domain = 'global'
    # per-facet padding element counts; nonzero only for irregular
    # domains such as ASTE
    pad_before = [0] * _nfacets
    pad_after = [0] * _nfacets

    def __init__(self, store):
        """Initialize model

        Parameters
        ----------
        store : llcreader.BaseStore
        mask_ds : zarr.Group
            Must contain variables `mask_c`, `mask_w`, `mask_s`
        """
        self.store = store
        self.shape = (self.nz, self.nface, self.nx, self.nx)
        if self.store.shrunk:
            # shrunk stores keep only wet points, so masks and index
            # tables are needed to decompress the data
            self.masks = self._get_masks()
            from .shrunk_index import all_index_data
            mykey = self.nx if self.domain == 'global' else f'{self.domain}_{self.nx}'
            self.indexes = all_index_data[mykey]
        else:
            self.masks = None
            self.indexes = None

    def _get_masks(self):
        # load the wet-point masks (one per grid point type) as facet lists
        masks = {}
        zgroup = self.store.open_mask_group()
        for point in ['c', 'w', 's']:
            mask_faces = dsa.from_zarr(zgroup['mask_' + point]).astype('bool')
            masks[point] = _faces_to_facets(mask_faces, self.nface)
        return masks

    def _dtype(self, varname=None):
        # dtype may be a single dtype for all variables, or a mapping
        # from variable name to dtype
        if isinstance(self.dtype, np.dtype):
            return self.dtype
        elif isinstance(self.dtype, dict):
            return np.dtype(self.dtype[varname])

    def _get_kp1_levels(self, k_levels):
        # determine kp1 levels: the borders of all k (center) levels;
        # used later (in get_dataset) to derive Zu and Zl from RF
        ku = k_levels[1:] + [k_levels[-1] + 1]
        kp1 = []
        # fix: removed unused `ki` local and unused enumerate index
        for x, y in zip(k_levels, ku):
            kp1 += [x] if x not in kp1 else []
            kp1 += [y] if y - x == 1 else [x + 1]
        return np.array(kp1)

    def _make_coords_faces(self, all_iters):
        # build the coordinate-only dataset in faces layout
        time = self.delta_t * all_iters
        # NOTE(review): `calendar` is not defined on this base class --
        # concrete model classes are expected to provide it
        time_attrs = {'units': self.time_units,
                      'calendar': self.calendar}
        coords = {'face': ('face', np.arange(self.nface)),
                  'i': ('i', np.arange(self.nx)),
                  'i_g': ('i_g', np.arange(self.nx)),
                  'j': ('j', np.arange(self.nx)),
                  'j_g': ('j_g', np.arange(self.nx)),
                  'k': ('k', np.arange(self.nz)),
                  'k_u': ('k_u', np.arange(self.nz)),
                  'k_l': ('k_l', np.arange(self.nz)),
                  'k_p1': ('k_p1', np.arange(self.nz + 1)),
                  'niter': ('time', all_iters),
                  'time': ('time', time, time_attrs)
                  }
        ds = xr.decode_cf(xr.Dataset(coords=coords))
        # attach standard dimension attributes
        # (`dimensions` is imported at module level)
        for d in dimensions:
            if d in ds:
                ds[d].attrs.update(dimensions[d]['attrs'])
        return ds

    def _make_coords_latlon(self, all_iters):
        # fix: the original definition was missing `self` and passed the
        # model object itself as the iteration array, so it always crashed
        ds = self._make_coords_faces(all_iters)
        return _faces_coords_to_latlon(ds)

    def _dask_array(self, nfacet, varname, iters, klevels, k_chunksize):
        # return a dask array for a single facet
        facet_shape = _facet_shape(nfacet, self.nx, self.nface)
        time_chunks = (len(iters) * (1,),) if iters is not None else ()
        k_chunks = (tuple([len(c)
                           for c in _chunks(klevels, k_chunksize)]),)
        chunks = time_chunks + k_chunks + tuple([(s,) for s in facet_shape])

        # manually build dask graph
        dsk = {}
        token = tokenize(varname, self.store, nfacet)
        name = '-'.join([varname, token])
        dtype = self._dtype(varname)

        # iters == None for grid variables
        def _key_and_task(n_k, these_klevels, n_iter=None, iternum=None):
            # grid variables have no time axis, so their keys are shorter
            if n_iter is None:
                key = name, n_k, 0, 0, 0
            else:
                key = name, n_iter, n_k, 0, 0, 0
            task = (_get_facet_chunk, self.store, varname, iternum,
                    nfacet, these_klevels, self.nx, self.nz, self.nface,
                    dtype, self.mask_override, self.domain,
                    self.pad_before, self.pad_after)
            return key, task

        if iters is not None:
            for n_iter, iternum in enumerate(iters):
                for n_k, these_klevels in enumerate(_chunks(klevels, k_chunksize)):
                    key, task = _key_and_task(n_k, these_klevels, n_iter, iternum)
                    dsk[key] = task
        else:
            for n_k, these_klevels in enumerate(_chunks(klevels, k_chunksize)):
                key, task = _key_and_task(n_k, these_klevels)
                dsk[key] = task

        return dsa.Array(dsk, name, chunks, dtype)

    def _dask_array_vgrid(self, varname, klevels, k_chunksize):
        # return a dask array for a 1D vertical grid var
        # single chunk for 1D variables
        chunks = ((len(klevels),),)

        # manually build dask graph
        dsk = {}
        token = tokenize(varname, self.store)
        name = '-'.join([varname, token])
        dtype = self._dtype(varname)

        # interface variables (dims == ['k_p1']) have one extra level
        nz = self.nz if _VAR_METADATA[varname]['dims'] != ['k_p1'] else self.nz + 1
        task = (_get_1d_chunk, self.store, varname,
                list(klevels), nz, dtype)

        key = name, 0
        dsk[key] = task

        return dsa.Array(dsk, name, chunks, dtype)

    def _get_facet_data(self, varname, iters, klevels, k_chunksize):
        # needs facets to be outer index of nested lists
        dims = _VAR_METADATA[varname]['dims']

        # 2D variables only have a single (dummy) vertical level
        if len(dims) == 2:
            klevels = [0, ]

        if _is_vgrid(varname):
            data_facets = self._dask_array_vgrid(varname, klevels, k_chunksize)
        else:
            data_facets = [self._dask_array(nfacet, varname, iters, klevels, k_chunksize)
                           for nfacet in range(5)]

        if len(dims) == 2:
            # squeeze depth dimension out of 2D variable
            data_facets = [facet[..., 0, :, :, :] for facet in data_facets]

        return data_facets

    def _check_iter_start(self, iter_start):
        # warn if the requested start is not aligned with the model's
        # known iteration spacing
        if self.iter_start is not None and self.iter_step is not None:
            if (iter_start - self.iter_start) % self.iter_step:
                msg = "Iteration {} may not exist, you may need to change 'iter_start'".format(iter_start)
                warnings.warn(msg, RuntimeWarning)

    def _check_iter_step(self, iter_step):
        # warn if the requested step skips over unavailable iterations
        if self.iter_step is not None:
            if iter_step % self.iter_step:
                msg = "'iter_step' is not a multiple of {}, meaning some expected timesteps may not be returned".format(self.iter_step)
                warnings.warn(msg, RuntimeWarning)

    def _check_iters(self, iters):
        # warn if any explicitly requested iteration looks unavailable
        if self.iters is not None:
            if not set(iters) <= set(self.iters):
                msg = "Some requested iterations may not exist, you may need to change 'iters'"
                warnings.warn(msg, RuntimeWarning)
        elif self.iter_start is not None and self.iter_step is not None:
            for iter in iters:
                if (iter - self.iter_start) % self.iter_step:
                    msg = "Some requested iterations may not exist, you may need to change 'iters'"
                    warnings.warn(msg, RuntimeWarning)
                    break

    def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
                    iter_step=None, iters=None, k_levels=None, k_chunksize=1,
                    type='faces', read_grid=True, grid_vars_to_coords=True):
        """
        Create an xarray Dataset object for this model.

        Parameters
        ----------
        varnames : list of strings, optional
            The variables to include, e.g. ``['Salt', 'Theta']``. Otherwise
            include all known variables.
        iter_start : int, optional
            Starting iteration number. Otherwise use model default.
            Follows standard `range` conventions. (inclusive)
        iter_stop : int, optional
            Stopping iteration number. Otherwise use model default.
            Follows standard `range` conventions. (exclusive)
        iter_step : int, optional
            Iteration number stepsize. Otherwise use model default.
        iters : list of ints, optional
            Specific iteration numbers in a list, possibly with nonuniform spacing.
            Either provide this or the iter parameters above.
        k_levels : list of ints, optional
            Vertical levels to extract. Default is to get them all
        k_chunksize : int, optional
            How many vertical levels per Dask chunk.
        type : {'faces', 'latlon'}, optional
            What type of dataset to create
        read_grid : bool, optional
            Whether to read the grid info
        grid_vars_to_coords : bool, optional
            Whether to promote grid variables to coordinate status

        Returns
        -------
        ds : xarray.Dataset
        """

        def _if_not_none(a, b):
            if a is None:
                return b
            else:
                return a

        user_iter_params = [iter_start, iter_stop, iter_step]
        attribute_iter_params = [self.iter_start, self.iter_stop, self.iter_step]

        # If the user has specified some iter params:
        if any([a is not None for a in user_iter_params]):
            # If iters is also set we have a problem
            if iters is not None:
                raise ValueError("Only `iters` or the parameters `iter_start`, `iters_stop`, "
                                 "and `iter_step` can be provided. Both were provided")

            # Otherwise we can override any missing values
            iter_start = _if_not_none(iter_start, self.iter_start)
            iter_stop = _if_not_none(iter_stop, self.iter_stop)
            iter_step = _if_not_none(iter_step, self.iter_step)
            iter_params = [iter_start, iter_stop, iter_step]
            if any([a is None for a in iter_params]):
                raise ValueError("The parameters `iter_start`, `iter_stop`, "
                                 "and `iter_step` must be defined either by the "
                                 "model class or as argument. Instead got %r "
                                 % iter_params)

        # Otherwise try loading from the user set iters
        elif iters is not None:
            pass

        # Now have a go at using the attribute derived iteration parameters
        elif all([a is not None for a in attribute_iter_params]):
            iter_params = attribute_iter_params

        # Now try using the attribute derived iters
        elif self.iters is not None:
            iters = self.iters

        # Now give up
        else:
            raise ValueError("The parameters `iter_start`, `iter_stop`, "
                             "and `iter_step`, or `iters` must be defined either by the "
                             "model class or as argument")

        # Check the iter_start and iter_step
        if iters is None:
            self._check_iter_start(iter_params[0])
            self._check_iter_step(iter_params[2])
            iters = np.arange(*iter_params)
        else:
            self._check_iters(iters)
            iters = np.array(iters)

        varnames = varnames or self.varnames

        # grid stuff
        read_grid = read_grid and len(self.grid_varnames) != 0
        if read_grid and self.store.grid_path is None:
            raise TypeError('Cannot read grid if grid_path is not specified in filestore (e.g. llcreader.known_models)')
        grid_vars_to_coords = grid_vars_to_coords and read_grid
        grid_varnames = self.grid_varnames if read_grid else []

        ds = self._make_coords_faces(iters)
        if type == 'latlon':
            if self.domain == 'aste':
                raise TypeError('Swapping to lat/lon not available for ASTE. Must regrid or interpolate.')
            ds = _faces_coords_to_latlon(ds)

        k_levels = k_levels or list(range(self.nz))
        kp1_levels = self._get_kp1_levels(k_levels)

        ds = ds.sel(k=k_levels, k_l=k_levels, k_u=k_levels, k_p1=kp1_levels)

        # get the data in facet form
        data_facets = {vname:
                       self._get_facet_data(vname, iters, k_levels, k_chunksize)
                       for vname in varnames}

        # get the grid in facet form
        # do separately for vertical coords on kp1_levels
        grid_facets = {}
        for vname in grid_varnames:
            my_k_levels = k_levels if _VAR_METADATA[vname]['dims'] != ['k_p1'] else kp1_levels
            grid_facets[vname] = self._get_facet_data(vname, None, my_k_levels, k_chunksize)

        # transform it into faces or latlon
        data_transformers = {'faces': _all_facets_to_faces,
                             'latlon': _all_facets_to_latlon}

        transformer = data_transformers[type]
        data = transformer(data_facets, _VAR_METADATA, self.nface)

        # separate horizontal and vertical grid variables
        hgrid_facets = {key: grid_facets[key]
                        for key in grid_varnames if not _is_vgrid(key)}
        vgrid_facets = {key: grid_facets[key]
                        for key in grid_varnames if _is_vgrid(key)}

        # do not transform vertical grid variables
        data.update(transformer(hgrid_facets, _VAR_METADATA, self.nface))
        data.update(vgrid_facets)

        variables = {}
        gridlist = ['Zl', 'Zu'] if read_grid else []
        for vname in varnames + grid_varnames:
            meta = _VAR_METADATA[vname]
            dims = meta['dims']
            if type == 'faces':
                dims = _add_face_to_dims(dims)
            # grid variables have no time dimension
            dims = ['time', ] + dims if vname not in grid_varnames else dims
            attrs = meta['attrs']

            # Handle grid names different from filenames
            fname = vname
            vname = meta['real_name'] if 'real_name' in meta else vname
            if fname in grid_varnames:
                gridlist.append(vname)

            variables[vname] = xr.Variable(dims, data[fname], attrs)

        # handle vertical coordinate after the fact: Zl and Zu are just
        # shifted views into RF (the k_p1 interface depths)
        if read_grid and 'RF' in grid_varnames:
            ki = np.array([list(kp1_levels).index(x) for x in k_levels])
            for zv, sl in zip(['Zl', 'Zu'], [ki, ki + 1]):
                variables[zv] = xr.Variable(_VAR_METADATA[zv]['dims'],
                                            data['RF'][sl],
                                            _VAR_METADATA[zv]['attrs'])

        ds = ds.update(variables)

        if grid_vars_to_coords:
            ds = ds.set_coords(gridlist)

        return ds