Source code for spines.models.base

# -*- coding: utf-8 -*-
"""
Model class for the spines package.
"""
from __future__ import annotations

from abc import abstractmethod
from typing import Any
from typing import Dict
from typing import List
from typing import Type

from ..core.base import BaseObjectException
from ..decorators.output import negate
from ..utils.file import load_pickle
from ..utils.file import save_pickle
from ..parameters.base import HyperParameter
from ..parameters.decorators import finalize_post
from ..parameters.decorators import finalize_pre
from ..parameters.store import ParameterStore
from ..transforms.base import Transform


class Model(Transform):
    """
    Spines primary Model class
    """
    # ParameterStore (sub)class used to hold this model's hyper-parameters;
    # subclasses may override it to customise the store.
    __hyperparam_store__: Type[ParameterStore] = ParameterStore

    def __init__(self, *args, **kwargs) -> None:
        # The hyper-parameter store is built *before* delegating to the
        # parent initializer. NOTE(review): presumably the parent calls
        # ``_modify_methods`` (which reads ``self._hyper_params``) during
        # construction, so the store must already exist -- confirm against
        # ``Transform.__init__``.
        self._hyper_params = self._create_store(
            self.__hyperparam_store__, HyperParameter
        )
        super().__init__(*args, **kwargs)

    @property
    def hyper_parameters(self) -> ParameterStore:
        """ParameterStore: Hyper-parameters which are currently set."""
        return self._hyper_params

    def set_hyper_params(self, **hyper_params) -> None:
        """Sets the values of this model's hyper-parameters

        Parameters
        ----------
        hyper_params
            Hyper-parameter values to set, given as keyword arguments
            (``name=value``).

        Raises
        ------
        InvalidParameterException
            If one of the given hyper-parameter values is not valid.

        See Also
        --------
        hyper_parameters, set_hyper_parameter

        """
        self._hyper_params.update(hyper_params)

    def get_hyper_params(self) -> Dict[str, Any]:
        """Gets the current hyper-parameter values

        Returns
        -------
        dict
            The currently set hyper-parameter values, as exposed by the
            hyper-parameter store's ``values`` attribute.

        See Also
        --------
        hyper_parameters, set_hyper_params

        """
        # NOTE(review): earlier docs promised a *copy* of the values;
        # confirm that ``ParameterStore.values`` does not hand out the
        # store's live mapping.
        return self._hyper_params.values

    def set_hyper_parameter(self, name: str, value: Any) -> None:
        """Sets a hyper-parameter value

        Sets the value of the hyper-parameter `name` if both `name` and
        `value` are valid.

        Parameters
        ----------
        name : str
            Hyper-parameter to set value for.
        value
            Value to set.

        Raises
        ------
        MissingParameterException
            If the given `name` hyper-parameter does not exist.
        InvalidParameterException
            If the given `value` is not valid for the specified
            hyper-parameter.

        See Also
        --------
        hyper_parameters, set_hyper_params

        """
        # The store lives in ``_hyper_params`` (created in ``__init__``);
        # validation is delegated to the store's item assignment.
        self._hyper_params[name] = value

    def unset_hyper_parameter(self, name: str) -> Any:
        """Un-sets a hyper-parameter

        Un-sets the specified hyper-parameter's value from the set of
        hyper-parameters and returns the previously set value.

        Parameters
        ----------
        name : str
            Name of the hyper-parameter to clear the value for.

        Returns
        -------
        object
            Previously set value of the hyper-parameter.

        Raises
        ------
        MissingParameterException
            If the given `name` hyper-parameter does not exist.

        See Also
        --------
        hyper_parameters, set_hyper_params

        """
        # Same store as ``set_hyper_parameter``: ``_hyper_params``.
        return self._hyper_params.pop(name)

    def fit(self, *args, **kwargs) -> None:
        """Fits the model

        Generally this method is used for a single iteration of model
        fitting (and for simple models this may be the only call which is
        required).  The :obj:`train` method can call this function multiple
        times and update the model iteratively (where that approach is
        appropriate).

        Parameters
        ----------
        args : optional
            Arguments to use in fit call.
        kwargs : optional
            Any additional keyword arguments to use in fit call.

        Returns
        -------
        :obj:`None` or :obj:`dict`
            Either returns `None` if adjustments to the model's parameters
            happen internally, otherwise returns the dictionary of updated
            parameters to apply.

        See Also
        --------
        train, transform

        """
        return super().fit(*args, **kwargs)

    def transform(self, *args, **kwargs) -> Any:
        """Transforms the given input data

        For a model, transforming inputs means predicting on them, so this
        simply forwards to :meth:`predict`.

        Parameters
        ----------
        args : optional
            Additional arguments to pass to predict call.
        kwargs : optional
            Additional keyword arguments to pass to predict call.

        Returns
        -------
        object
            Transformed inputs.

        """
        return self.predict(*args, **kwargs)

    @abstractmethod
    def predict(self, *args, **kwargs) -> Any:
        """Predict outputs for the given inputs

        Must be implemented by concrete subclasses.

        Parameters
        ----------
        args : optional
            Additional arguments to pass to predict call.
        kwargs : optional
            Additional keyword arguments to pass to predict call.

        Returns
        -------
        object
            Predictions from the given data.

        """
        pass

    def _save_helper(self, dir_path: str) -> List[str]:
        """Saves Model objects to the specified directory

        Extends the parent's save by pickling the hyper-parameter store
        under the name ``'hyperparameters'``.

        Returns
        -------
        list of str
            The parent's returned entries plus the value returned by
            ``save_pickle`` for the hyper-parameters.

        """
        saved = super()._save_helper(dir_path)
        saved.append(
            save_pickle(self._hyper_params, dir_path, 'hyperparameters')
        )
        return saved

    @classmethod
    def _load_helper(cls, dir_path: str, new: bool) -> Model:
        """Helper function for loading a Model from file

        Loads the instance via the parent helper, then restores the
        hyper-parameter store pickled by :meth:`_save_helper`.

        """
        # NOTE(review): this unpickles data from ``dir_path``; only load
        # models from trusted locations.
        instance = super()._load_helper(dir_path, new)
        instance._hyper_params = load_pickle(dir_path, 'hyperparameters')
        return instance

    def _modify_methods(self, *args, **kwargs) -> None:
        """Modifies the model's functions in-place on object creation

        Wraps :meth:`fit` with ``finalize_pre`` (given the hyper-parameter
        store) and ``finalize_post`` (given ``self._params``). If exactly
        one of ``error`` / ``score`` carries an ``__overridden__`` marker,
        the other is replaced by its negation.

        """
        super()._modify_methods(*args, **kwargs)

        # ``self._params`` is not set in this class; presumably it is
        # created by the parent -- confirm.
        self.fit = finalize_pre(self.fit, self._hyper_params)
        self.fit = finalize_post(self.fit, self._params)

        # ``__overridden__`` presumably marks a metric a subclass actually
        # implemented; derive the missing one as the negation of the other.
        if (hasattr(self.error, '__overridden__') and
                not hasattr(self.score, '__overridden__')):
            self.score = negate(self.error)
        elif (hasattr(self.score, '__overridden__') and
                not hasattr(self.error, '__overridden__')):
            self.error = negate(self.score)
class ModelException(BaseObjectException):
    """Root of the exception hierarchy for errors raised by Model objects."""