# -*- coding: utf-8 -*-
"""
Base classes for core spines library.
"""
from __future__ import annotations

from abc import ABC
import tempfile
from typing import Any
from typing import List
from typing import Mapping
from typing import Type
from typing import Union

from ..decorators.mark import override
from ..parameters.base import Parameter
from ..parameters.store import ParameterStore
from ..utils.file import extract_archive
from ..utils.file import save_archive
from ..utils.file import save_pickle
from ..utils.file import load_pickle
from ..utils.object import get_overridden_functions
class BaseObject(ABC):
    """
    Base object class for all spines components

    Every instance owns a parameter store (``self._params``) that is
    populated from the :class:`Parameter` attributes declared directly
    on the instance's class.  Instances can be saved to, and loaded
    from, an archive file.

    Attributes
    ----------
    __version__
        Version marker shown in :meth:`__repr__` (``None`` here).
    __param_store__ : type
        Class used to build the parameter store for new instances.
    """

    __version__ = None
    __param_store__: Type[ParameterStore] = ParameterStore

    def __init__(self, *args, **kwargs) -> None:
        # Positional/keyword arguments are accepted but not used here.
        self._params = self._create_store(
            self.__param_store__, Parameter
        )
        self._modify_methods()

    def __str__(self) -> str:
        return self.__class__.__name__

    def __repr__(self) -> str:
        # Parameter names are sorted so the representation is stable.
        return '<%s version="%s" parameters="%s">' % (
            self.__class__.__name__, self.__version__,
            ', '.join(sorted(self.parameters.keys()))
        )

    @property
    def parameters(self) -> ParameterStore:
        """ParameterStore: Parameters in this object."""
        return self._params

    def set_params(self, **params) -> None:
        """Sets the values for this object's parameters

        Forwards the given keyword arguments to the parameter store's
        ``update`` method.

        Parameters
        ----------
        params
            Parameters and values to set.

        Raises
        ------
        InvalidParameterException
            If the given `name` or `value` are not valid.

        """
        self._params.update(params)

    def get_params(self) -> Mapping:
        """Gets this object's parameter values

        Returns
        -------
        Mapping
            The ``values`` attribute of the parameter store (described
            by the original documentation as a copy of the currently
            set parameter names and values).

        """
        return self._params.values

    def set_parameter(self, name: str, value: Any) -> None:
        """Sets a parameter value

        Will set the given `name` to `value` in the parameters if
        they are valid, throws an exception if they are not.

        Parameters
        ----------
        name : str
            Parameter to set the value for.
        value
            Value to set.

        Raises
        ------
        InvalidParameterException
            If the given `name` or `value` are not valid.

        See Also
        --------
        parameters

        """
        self._params[name] = value

    def unset_parameter(self, name: str) -> Any:
        """Unsets a parameter value

        Removes the specified parameter's value from the parameter
        values if it is part of the parameter set and returns its
        current value.

        Parameters
        ----------
        name : str
            Name of the parameter whose value needs to be un-set.

        Returns
        -------
        object
            Previously set value of the parameter.

        Raises
        ------
        MissingParameterException
            If the parameter to remove does not exist in the set of
            parameters.

        See Also
        --------
        parameters

        """
        return self._params.pop(name)

    def save(
        self, path: Union[None, str] = None, fmt: Union[None, str] = None
    ) -> str:
        """Saves this object to file

        The object's parts are first written into a temporary
        directory and then bundled into a single archive.

        Parameters
        ----------
        path : str, optional
            File path to save this object to (default is :obj:`None`,
            which uses the result of :meth:`_get_file_path`).
        fmt : str, optional
            Format to save this object with.

        Returns
        -------
        str
            Path to the saved object file.

        """
        if path is None:
            # Previously the default path was computed but discarded,
            # so `path` always stayed None here.
            path = self._get_file_path()
        with tempfile.TemporaryDirectory(prefix='spines-') as tmp_dir:
            files = self._save_helper(tmp_dir)
            # Archive inside the context manager: the temporary files
            # are deleted as soon as the ``with`` block exits.
            return save_archive(path, files, fmt=fmt)

    def _save_helper(self, dir_path: str) -> List[str]:
        """Saves the relevant parts of this object to file

        Parameters
        ----------
        dir_path : str
            Path to the directory to save the files to.

        Returns
        -------
        :obj:`list` of :obj:`str`
            List of files saved (the class first, then the parameters).

        """
        return [
            save_pickle(self.__class__, dir_path, 'class'),
            save_pickle(self._params, dir_path, 'parameters'),
        ]

    @classmethod
    def load(
        cls, path: str, fmt: Union[None, str] = None, new: bool = False
    ) -> BaseObject:
        """Loads an object from file

        Parameters
        ----------
        path : str
            Path to the file to load from.
        fmt : str, optional
            Format to use when loading the file (default is :obj:`None`
            which will infer based on the `path`, if possible).
        new : bool, optional
            Whether or not to create a new instance from this (the
            calling) class or to use the stored class object (default is
            :obj:`False`, use the saved version).

        Returns
        -------
        BaseObject
            The new object loaded from file.

        Notes
        -----
        SECURITY: the archive contents are restored through the
        pickle-based helpers, so only load files from trusted sources.

        """
        with tempfile.TemporaryDirectory(prefix='spines-') as tmp_dir:
            extract_archive(path, tmp_dir, fmt=fmt)
            # Must happen before the temporary directory is removed.
            return cls._load_helper(tmp_dir, new)

    @classmethod
    def _load_helper(cls, dir_path: str, new: bool) -> BaseObject:
        """Loads the various files into a new object

        Parameters
        ----------
        dir_path : str
            Directory holding the extracted ``class`` and
            ``parameters`` files.
        new : bool
            If true instantiate the calling class, otherwise
            instantiate the class stored in the archive.

        Returns
        -------
        BaseObject
            Instance whose parameter store was replaced by the stored
            one.

        """
        if new:
            instance = cls()
        else:
            # SECURITY: presumably unpickles (and then calls) a class
            # object taken from the archive -- never use this on
            # untrusted files.
            instance = load_pickle(dir_path, 'class')()
        instance._params = load_pickle(dir_path, 'parameters')
        return instance

    def _get_file_path(self) -> str:
        """Gets the default file path for saving to

        NOTE(review): this base implementation is a stub and returns
        :obj:`None` despite the ``str`` annotation; presumably
        subclasses are expected to override it -- confirm.
        """
        return None

    def _modify_methods(self, *args, **kwargs) -> None:
        """Helper function to modify this class's methods

        Currently only marks overridden methods; the extra arguments
        are accepted but unused in this implementation.
        """
        self._mark_overridden_methods()

    @classmethod
    def _create_store(cls, store_cls, param_cls) -> ParameterStore:
        """Creates an instance of the parameter store

        Parameters
        ----------
        store_cls
            Callable returning a new, empty store.
        param_cls
            Type used to recognize parameter attributes on the class.

        Returns
        -------
        ParameterStore
            Store containing every `param_cls` attribute found.

        """
        store = store_cls()
        # NOTE(review): only attributes defined directly on `cls` are
        # scanned (``cls.__dict__``); parameters declared on parent
        # classes are not picked up -- confirm this is intended.
        for attr in cls.__dict__.values():
            if isinstance(attr, param_cls):
                store.add(attr)
        return store

    def _mark_overridden_methods(self) -> None:
        """Marks the methods overridden in this object's implementation

        Each attribute name reported by ``get_overridden_functions`` is
        re-bound on the instance, wrapped with ``override``.
        """
        # The comparison is made against the last listed base class.
        base_cls = type(self).__bases__[-1]
        for method in get_overridden_functions(base_cls, self):
            setattr(self, method, override(getattr(self, method)))
class BaseObjectException(Exception):
    """
    Base exception class for spines objects.

    Adds no behavior of its own; it exists so that errors raised by
    spines objects can share a common ancestor.
    """