# Source code for temporaldata.data

from __future__ import annotations

import copy
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Literal, Tuple, Union, Callable, Optional, Type
from pathlib import Path
import warnings

import h5py
import numpy as np

from .arraydict import ArrayDict, LazyArrayDict
from .irregular_ts import IrregularTimeSeries, LazyIrregularTimeSeries
from .regular_ts import RegularTimeSeries, LazyRegularTimeSeries
from .interval import Interval, LazyInterval
from .utils import _size_repr


class Data(object):
    r"""A flexible container for other data objects such as :obj:`ArrayDict`,
    :obj:`RegularTimeSeries`, :obj:`IrregularTimeSeries`, and :obj:`Interval`
    objects, as well as nested :obj:`Data` objects and regular Python objects like
    scalars, strings, and numpy arrays.

    Args:
        **kwargs: Arbitrary attributes to attach to the data object (e.g. spikes,
            lfp, units, trials, metadata).
        domain: An :obj:`Interval` specifying the time domain of the data object.
            If ``"auto"``, the domain is computed as the union of the domains of
            any time-based attributes. Defaults to :obj:`None`.

    Example ::

        >>> import numpy as np
        >>> from temporaldata import (
        ...     ArrayDict,
        ...     IrregularTimeSeries,
        ...     RegularTimeSeries,
        ...     Interval,
        ...     Data,
        ... )

        >>> data = Data(
        ...     session_id="session_0",
        ...     spikes=IrregularTimeSeries(
        ...         timestamps=np.array([0.1, 0.2, 0.3, 2.1, 2.2, 2.3]),
        ...         unit_index=np.array([0, 0, 1, 0, 1, 2]),
        ...         waveforms=np.zeros((6, 48)),
        ...         domain=Interval(0., 3.),
        ...     ),
        ...     lfp=RegularTimeSeries(
        ...         raw=np.zeros((1000, 3)),
        ...         sampling_rate=250.,
        ...         domain=Interval(0., 4.),
        ...     ),
        ...     units=ArrayDict(
        ...         id=np.array(["unit_0", "unit_1", "unit_2"]),
        ...         brain_region=np.array(["M1", "M1", "PMd"]),
        ...     ),
        ...     trials=Interval(
        ...         start=np.array([0, 1, 2]),
        ...         end=np.array([1, 2, 3]),
        ...         go_cue_time=np.array([0.5, 1.5, 2.5]),
        ...         drifting_gratings_dir=np.array([0, 45, 90]),
        ...     ),
        ...     drifting_gratings_imgs=np.zeros((8, 3, 32, 32)),
        ...     domain=Interval(0., 4.),
        ... )

        >>> data
        Data(
        session_id='session_0',
        spikes=IrregularTimeSeries(
          timestamps=[6],
          unit_index=[6],
          waveforms=[6, 48]
        ),
        lfp=RegularTimeSeries(
          raw=[1000, 3]
        ),
        units=ArrayDict(
          id=[3],
          brain_region=[3]
        ),
        trials=Interval(
          start=[3],
          end=[3],
          go_cue_time=[3],
          drifting_gratings_dir=[3]
        ),
        drifting_gratings_imgs=[8, 3, 32, 32],
        )

        >>> data.slice(1, 3)
        Data(
        session_id='session_0',
        spikes=IrregularTimeSeries(
          timestamps=[3],
          unit_index=[3],
          waveforms=[3, 48]
        ),
        lfp=RegularTimeSeries(
          raw=[500, 3]
        ),
        units=ArrayDict(
          id=[3],
          brain_region=[3]
        ),
        trials=Interval(
          start=[2],
          end=[2],
          go_cue_time=[2],
          drifting_gratings_dir=[2]
        ),
        drifting_gratings_imgs=[8, 3, 32, 32],
        _absolute_start=1.0,
        )
    """

    # class-level defaults; instances only shadow them once they are assigned
    _absolute_start = 0.0
    _domain = None
    _file: Optional[h5py.File] = None

    def __init__(
        self,
        *,
        domain: Union[Interval, Literal["auto"], None] = None,
        **kwargs,
    ):
        # Only compare against "auto" when ``domain`` is actually a string, so an
        # Interval (or any array-like) is never compared to a str.
        if isinstance(domain, str) and domain == "auto":
            # the domain is the union of the domains of the attributes
            domain = Interval(np.array([]), np.array([]))
            for value in kwargs.values():
                if isinstance(value, (IrregularTimeSeries, RegularTimeSeries)):
                    domain = domain | value.domain
                if isinstance(value, Interval):
                    domain = domain | value
                if isinstance(value, Data) and value.domain is not None:
                    domain = domain | value.domain

        if domain is not None and not isinstance(domain, Interval):
            raise ValueError("domain must be an Interval object.")

        # ``_domain`` is assigned before any other attribute so that
        # ``__setattr__`` can validate time-based attributes against it.
        self._domain = domain

        for key, value in kwargs.items():
            setattr(self, key, value)

        # TODO: check that time-based attributes lie within the domain, also when a
        # new attribute is set after construction.

    def __setattr__(self, name, value):
        # Refuse to attach a time-based attribute (or a nested Data that has a
        # domain) while this object has no domain. ``_domain`` itself is exempt so
        # that it can be assigned in the first place.
        if name != "_domain" and (
            (
                isinstance(value, (IrregularTimeSeries, RegularTimeSeries, Interval))
                and self.domain is None
            )
            or (
                isinstance(value, Data)
                and self.domain is None
                and value.domain is not None
            )
        ):
            raise ValueError(
                f"Data object must have a domain if it contains a time-based attribute "
                f"({name})."
            )
        super().__setattr__(name, value)

    def __getattr__(self, name) -> Any:
        # only reached when normal attribute lookup fails
        raise AttributeError(f"Attribute {name} not found.")

    @property
    def domain(self):
        r"""Returns the domain of the data object."""
        return self._domain

    @property
    def start(self):
        r"""Returns the start time of the data object (the start of the first
        domain interval), or :obj:`None` if there is no domain."""
        return self.domain.start[0] if self.domain is not None else None

    @property
    def end(self):
        r"""Returns the end time of the data object (the end of the last domain
        interval), or :obj:`None` if there is no domain."""
        return self.domain.end[-1] if self.domain is not None else None

    @property
    def absolute_start(self):
        r"""Returns the start time of this slice relative to the original start
        time. Should be 0. if the data object has not been sliced. Returns
        :obj:`None` if there is no domain.

        Example ::

            >>> from temporaldata import Data, Interval
            >>> data = Data(domain=Interval(0., 4.))
            >>> data.absolute_start
            0.0

            >>> data = data.slice(1, 3)
            >>> data.absolute_start
            1.0

            >>> data = data.slice(0.4, 1.4)
            >>> data.absolute_start
            1.4
        """
        return self._absolute_start if self.domain is not None else None

    def slice(self, start: float, end: float, reset_origin: bool = True):
        r"""Returns a new :obj:`Data` object that contains the data between the
        start and end times. This method will slice all time-based attributes that
        are present in the data object.

        Args:
            start: Start time.
            end: End time.
            reset_origin: If :obj:`True`, all time attributes will be updated to be
                relative to the new start time. Defaults to :obj:`True`.

        Returns:
            A new :obj:`Data` object. Non-time-based attributes are shallow-copied.

        Raises:
            ValueError: If the data object has no domain.
        """
        if self.domain is None:
            raise ValueError(
                "Data object does not contain any time-based attributes, "
                "and can thus not be sliced."
            )

        out = self.__class__.__new__(self.__class__)

        for key, value in self.__dict__.items():
            if key in ("_domain", "_file"):
                # the domain is recomputed below; the file handle is not carried
                # over, so ``out.close()`` never closes this object's file
                pass
            elif isinstance(
                value, (IrregularTimeSeries, RegularTimeSeries, Interval)
            ) or (isinstance(value, Data) and value.domain is not None):
                out.__dict__[key] = value.slice(start, end, reset_origin)
            else:
                out.__dict__[key] = copy.copy(value)

        # update domain
        out._domain = copy.copy(self._domain) & Interval(start, end)
        if reset_origin:
            out._domain.start -= start
            out._domain.end -= start

        # update slice start time
        out._absolute_start = self._absolute_start + start

        return out

    def select_by_interval(self, interval: Interval):
        r"""Returns a new :obj:`Data` object in which every time-based attribute
        is restricted to ``interval`` via its own ``select_by_interval``.

        :obj:`RegularTimeSeries` attributes are first converted with
        ``to_irregular()``. The domain of the result is the intersection of this
        object's domain with ``interval``. Non-time-based attributes are
        shallow-copied.

        Args:
            interval: Interval object.

        Raises:
            ValueError: If the data object has no domain.
        """
        if self.domain is None:
            raise ValueError(
                "Data object does not contain any time-based attributes, "
                "and can thus not be sliced."
            )

        out = self.__class__.__new__(self.__class__)

        for key, value in self.__dict__.items():
            if key in ("_domain", "_file"):
                pass
            elif isinstance(
                value, (IrregularTimeSeries, RegularTimeSeries, Interval)
            ) or (isinstance(value, Data) and value.domain is not None):
                if isinstance(value, RegularTimeSeries):
                    value = value.to_irregular()
                out.__dict__[key] = value.select_by_interval(interval)
            else:
                out.__dict__[key] = copy.copy(value)

        out._domain = self._domain & interval

        return out

    def __repr__(self) -> str:
        cls = self.__class__.__name__

        info = ""
        for key, value in self.__dict__.items():
            if key in ("_domain", "_file"):
                pass
            elif isinstance(value, ArrayDict):
                info = info + key + "=" + repr(value) + ",\n"
            elif value is not None:
                info = info + _size_repr(key, value) + ",\n"
        info = info.rstrip()
        return f"{cls}(\n{info}\n)"

    def to_dict(self) -> Dict[str, Any]:
        r"""Returns a deep copy of the stored key/value pairs (excluding the file
        handle) as a dictionary."""
        out_dict = {k: v for k, v in self.__dict__.items() if k != "_file"}
        return copy.deepcopy(out_dict)

    def to_hdf5(self, file, serialize_fn_map=None):
        r"""Saves the data object to an HDF5 file.

        This method will also call the `to_hdf5` method of all contained data
        objects, so that the entire data object is saved to the HDF5 file, i.e. no
        need to call `to_hdf5` for each contained data object. Attributes whose
        value is :obj:`None` are not written.

        Args:
            file (h5py.File): HDF5 file (or group) to write into.
            serialize_fn_map: Optional mapping from a type (or tuple of types) to
                a serialization function. It is applied to attributes that are
                neither :obj:`Data`, :obj:`ArrayDict` nor numpy arrays before
                they are stored as HDF5 attributes, and is forwarded to nested
                :obj:`Data` objects.

        .. code-block:: python

            import h5py
            from temporaldata import Data

            data = Data(...)

            with h5py.File("data.h5", "w") as f:
                data.to_hdf5(f)
        """
        for key in self.keys():
            value = getattr(self, key)
            if isinstance(value, (Data, ArrayDict)):
                grp = file.create_group(key)
                if isinstance(value, Data):
                    value.to_hdf5(grp, serialize_fn_map=serialize_fn_map)
                else:
                    value.to_hdf5(grp)
            elif isinstance(value, np.ndarray):
                # TODO: warn if the array is too large and recommend using ArrayDict
                file.create_dataset(key, data=value)
            elif value is not None:
                # each attribute should be small (generally < 64k);
                # there is no partial I/O, the entire attribute must be read
                value = _serialize(value, serialize_fn_map=serialize_fn_map)
                file.attrs[key] = value

        if self._domain is not None:
            grp = file.create_group("domain")
            self._domain.to_hdf5(grp)

        file.attrs["object"] = "Data"
        file.attrs["absolute_start"] = self._absolute_start

    @classmethod
    def from_hdf5(cls, file: h5py.File | h5py.Group, lazy: bool = True):
        r"""Loads the data object from an HDF5 file.

        This method will also call the `from_hdf5` method of all contained data
        objects, so that the entire data object is loaded from the HDF5 file, i.e.
        no need to call `from_hdf5` for each contained data object.

        Args:
            file (h5py.File): HDF5 file (or group). An :obj:`h5py.File` must be
                opened in read-only mode.
            lazy: If :obj:`True` (default), groups are loaded with their ``Lazy*``
                counterparts. If ``file`` is an :obj:`h5py.File`, it is kept on
                the returned object so it can later be closed with :meth:`close`.
                Plain datasets are always read into memory.

        .. code-block:: python

            import h5py
            from temporaldata import Data

            with h5py.File("data.h5", "r") as f:
                data = Data.from_hdf5(f)
        """
        # check that the file is read-only
        if isinstance(file, h5py.File):
            assert file.mode == "r", "File must be opened in read-only mode."

        data = {}
        for key, value in file.items():
            if isinstance(value, h5py.Group):
                # the ``object`` attribute holds the name of the class that wrote
                # the group (see ``to_hdf5``)
                class_name = value.attrs["object"]
                if lazy and class_name != "Data":
                    group_cls = globals()[f"Lazy{class_name}"]
                else:
                    group_cls = globals()[class_name]

                if class_name == "Data":
                    data[key] = group_cls.from_hdf5(value, lazy=lazy)
                else:
                    data[key] = group_cls.from_hdf5(value)
            else:
                # if array, it will be loaded no matter what, always prefer ArrayDict
                data[key] = value[:]

        for key, value in file.attrs.items():
            if key in ("object", "absolute_start"):
                continue
            data[key] = value

        obj = cls(**data)

        # restore the absolute start time; groups written without it are treated
        # as un-sliced
        obj._absolute_start = file.attrs.get("absolute_start", 0.0)

        if lazy and isinstance(file, h5py.File):
            obj._file = file

        return obj

    @property
    def file(self) -> h5py.File | None:
        r"""The underlying HDF5 file handle, or ``None`` if no file is open.

        Only set when the object was created via :meth:`load` or
        :meth:`from_hdf5` with ``lazy=True``."""
        return self._file

    @classmethod
    def load(cls, path: Union[Path, str], lazy: bool = True) -> Data:
        r"""Loads the :class:`Data` object from an HDF5 file given its file path.

        The file is always opened in read-only mode.

        When ``lazy=True`` (default), the underlying HDF5 file remains open and data
        is loaded on demand. The caller is responsible for closing the file handle
        when done, either by calling :meth:`close` or by using the context manager
        protocol.

        When ``lazy=False``, all data is read into memory immediately and the file
        is closed before returning.

        Args:
            path: The file path to the HDF5 file containing the :class:`Data`
                object.
            lazy: If True (default), load contained objects in lazy mode (using
                LazyArrayDict, LazyRegularTimeSeries, etc.); if False, read all
                data immediately into memory.

        Returns:
            Data: The loaded :class:`Data` object from the HDF5 file.

        .. code-block:: python

            from temporaldata import Data

            # lazy with context manager (recommended)
            with Data.load("data.h5") as data:
                ...

            # lazy with explicit close
            data = Data.load("data.h5")
            ...
            data.close()

            # non-lazy (no close needed)
            data = Data.load("data.h5", lazy=False)
        """
        # Open explicitly read-only: ``from_hdf5`` requires it, and older h5py
        # versions did not default to "r".
        file = h5py.File(path, "r")
        try:
            obj = cls.from_hdf5(file, lazy=lazy)
        except BaseException:
            file.close()
            raise

        if not lazy:
            file.close()

        return obj

    def close(self, strict: bool = False):
        r"""Close the file-handle that was opened for lazy-loading.

        Any lazy attributes that have not been materialized will become invalid.

        Args:
            strict: If ``True``, raise an error when no open file handle is
                present. Default ``False``.

        Raises:
            RuntimeError: If ``strict`` is ``True`` and no file handle is open.
        """
        if self._file is not None:
            self._file.close()
            self._file = None
            return

        if strict:
            raise RuntimeError("No file handle is open")

    def save(self, path: Union[Path, str], serialize_fn_map=None):
        r"""Saves the data object to an HDF5 file at the given path.

        An existing file at ``path`` is overwritten (the file is opened with mode
        ``"w"``).

        Args:
            path: Destination file path.
            serialize_fn_map: Optional serialization registry forwarded to
                :meth:`to_hdf5`. Defaults to :obj:`None`.

        .. code-block:: python

            from temporaldata import Data

            data = Data(...)
            data.save("data.h5")
        """
        with h5py.File(path, "w") as f:
            self.to_hdf5(f, serialize_fn_map=serialize_fn_map)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # close the lazy-loading file handle (if any); exceptions propagate
        self.close()
        return False

    def set_train_domain(self, interval: Interval):
        """Sets the ``train_domain`` attribute to ``interval``.

        Deprecated: assign ``train_domain`` directly instead. Emits a
        :class:`DeprecationWarning`.
        """
        warnings.warn(
            "set_train_domain() is being deprecated and will be removed in a future version. "
            "Please directly set the train_domain attribute of this Data object.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.train_domain = interval

    def set_valid_domain(self, interval: Interval):
        """Sets the ``valid_domain`` attribute to ``interval``.

        Deprecated: assign ``valid_domain`` directly instead. Emits a
        :class:`DeprecationWarning`.
        """
        warnings.warn(
            "set_valid_domain() is being deprecated and will be removed in a future version. "
            "Please directly set the valid_domain attribute of this Data object.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.valid_domain = interval

    def set_test_domain(self, interval: Interval):
        """Sets the ``test_domain`` attribute to ``interval``.

        Deprecated: assign ``test_domain`` directly instead. Emits a
        :class:`DeprecationWarning`.
        """
        warnings.warn(
            "set_test_domain() is being deprecated and will be removed in a future version. "
            "Please directly set the test_domain attribute of this Data object.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.test_domain = interval

    def _check_for_data_leakage(self, *args, **kwargs):
        """Deprecated no-op retained for backward compatibility.

        The ``_check_for_data_leakage`` method no longer performs any validation and
        always returns ``True``. Actual data leakage checks should be performed by
        the sampler.
        """
        warnings.warn(
            "_check_for_data_leakage() is being deprecated and will be removed in a future version. ",
            DeprecationWarning,
            stacklevel=2,
        )
        return True

    def keys(self) -> List[str]:
        r"""Returns a list of all public attribute names (those not starting with
        an underscore)."""
        return list(filter(lambda x: not x.startswith("_"), self.__dict__))

    def __contains__(self, key: str) -> bool:
        r"""Returns :obj:`True` if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys()

    def get_nested_attribute(self, path: str) -> Any:
        r"""Returns the attribute specified by the path. The path can be nested
        using dots. For example, if the path is "spikes.timestamps", this method
        will return the timestamps attribute of the spikes object.

        Args:
            path: Nested attribute path.

        Raises:
            AttributeError: If any component of the path cannot be resolved.
        """
        # Split key by dots, resolve using getattr
        components = path.split(".")
        out = self
        for c in components:
            try:
                out = getattr(out, c)
            except AttributeError:
                raise AttributeError(
                    f"Could not resolve {path} in data (specifically, at level {c})"
                )
        return out

    def set_nested_attribute(self, path: str, value: Any) -> Data:
        r"""Set a nested attribute specified by its path.

        The path can be nested using dots. For example, if the path is
        "session.id", this method will set the value of the ``id`` attribute of the
        ``session`` object. The attribute is modified in an in-place manner.

        Args:
            path: Nested attribute path (can be dot-separated, e.g. "session.id").
            value: The value to set for the attribute.

        Returns:
            Data: self with the updated nested attribute.

        Raises:
            AttributeError: If any component of the path cannot be resolved.
        """
        # Split key by dots, resolve using getattr
        components = path.split(".")
        obj = self
        for c in components[:-1]:
            try:
                obj = getattr(obj, c)
            except AttributeError:
                raise AttributeError(
                    f"Could not resolve {path} in data (specifically, at level {c})"
                )

        setattr(obj, components[-1], value)
        return self

    def has_nested_attribute(self, path: str) -> bool:
        """Check if the attribute specified by the (dot-separated) path exists in
        the Data object.

        Each level is looked up in the object's instance ``__dict__`` rather than
        via ``getattr``, so properties are not considered. Returns ``False`` for an
        empty path, or when an intermediate object has no ``__dict__`` (e.g. a
        numpy array).
        """
        if not path:
            return False

        current_obj = self
        for name in path.split("."):
            # NOTE(review): the direct ``__dict__`` lookup presumably avoids
            # triggering lazy loading — confirm before switching to getattr.
            attributes = getattr(current_obj, "__dict__", None)
            if attributes is None or name not in attributes:
                return False
            current_obj = attributes[name]
        return True

    def __copy__(self):
        # Create a shallow copy of the object: the full skeleton of the Data
        # object, i.e. including all nested Data and ArrayDict children, is copied.
        # However, the data itself (np.ndarray, etc.) is not copied.
        # Entries are written straight into ``__dict__`` (bypassing
        # ``__setattr__``) because the source is already valid. Objects built by
        # ``slice``/``select_by_interval`` store ``_domain`` last, so validating
        # attribute by attribute would wrongly reject their time-based attributes.
        cls = self.__class__
        result = cls.__new__(cls)
        for k, v in self.__dict__.items():
            if isinstance(v, (Data, ArrayDict)):
                v = copy.copy(v)
            result.__dict__[k] = v
        return result

    def __deepcopy__(self, memo):
        # Create a deep copy of the object.
        # h5py objects will not be deepcopied: we only allow read-only access to
        # the HDF5 file, so sharing them should not be an issue.
        # As in ``__copy__``, entries are written straight into ``__dict__`` so
        # that ``__setattr__`` validation does not depend on key order.
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if isinstance(v, (h5py.Dataset, h5py.File)):
                # open files cannot be deepcopied
                result.__dict__[k] = v
            else:
                result.__dict__[k] = copy.deepcopy(v, memo)
        return result

    def materialize(self) -> Data:
        r"""Materializes the data object, i.e., loads into memory all of the data
        that is still referenced in the HDF5 file.

        Calls ``materialize()`` on every :obj:`Data`/:obj:`ArrayDict` attribute and
        on the domain.

        Returns:
            Data: self, materialized in place.
        """
        for key in self.keys():
            value = getattr(self, key)
            if isinstance(value, (Data, ArrayDict)):
                value.materialize()

        if self.domain is not None:
            self.domain.materialize()

        return self
def _serialize( elem, serialize_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None, ): r""" General serialization function that handles object types that are not supported by hdf5. The function also opens function registry to deal with specific element types through `serialize_fn_map`. This function will automatically be applied to elements in a nested sequence structure. Args: elem: a single element to be serialized. serialize_fn_map: Optional dictionary mapping from element type to the corresponding serialize function. If the element type isn't present in this dictionary, it will be skipped and the element will be returned as is. """ elem_type = type(elem) if serialize_fn_map is not None: if elem_type in serialize_fn_map: return serialize_fn_map[elem_type](elem, serialize_fn_map=serialize_fn_map) for object_type in serialize_fn_map: if isinstance(elem, object_type): return serialize_fn_map[object_type]( elem, serialize_fn_map=serialize_fn_map ) if isinstance(elem, (list, tuple)): return elem_type( [_serialize(e, serialize_fn_map=serialize_fn_map) for e in elem] ) # element does not need to be seralized, or type not supported return elem