# -*- coding: utf-8 -*-
"""Base class for State objects."""
import inspect
import os
from datetime import datetime
from typing import List
from ._keys import STATES_KEY
from .. import ConfigDb, LOG
from .._events.event_queue import EventQueue
from .._events.pubsub import get_subscribers, publish, subscribe
from ..utils.datetime_utils import datetime_from_isoformat
# Module-level configuration database handle, created once at import time and
# used by StateObject to read and write its state hashes.
DB = ConfigDb()
class StateObject:
    """Base class for state objects (service state & sdp state).

    The state of each object is persisted in the configuration database as a
    hash stored under the key ``'<STATES_KEY>:<object_id>'`` with the fields
    ``current_state``, ``target_state``, ``current_timestamp`` and
    ``target_timestamp``.
    """

    def __init__(self, object_id: str, allowed_states: List[str],
                 allowed_transitions: dict, allowed_target_states: dict):
        """Initialise a state object.

        If no database entry exists for this object yet, one is created with
        both the current and target states set to 'unknown'.

        Args:
            object_id (str): Identifier of the state object.
            allowed_states (List[str]): List of allowed states.
            allowed_transitions (dict): Dict mapping a current state to the
                list of states it may transition to.
            allowed_target_states (dict): Dict mapping a current state to the
                list of target states that may be commanded from it.

        """
        self._id = object_id
        self._type = STATES_KEY  # object type
        self._key = '{}:{}'.format(STATES_KEY, self._id)
        # All state names are normalised to lower case so that comparisons
        # against user-supplied values are case-insensitive.
        self._allowed_states = [state.lower() for state in allowed_states]
        self._allowed_transitions = self._dict_lower(allowed_transitions)
        self._allowed_target_states = self._dict_lower(allowed_target_states)
        if not DB.key_exists(self._key):
            DB.save_dict(self._key, self._initialise())

    @property
    def id(self) -> str:
        """Return the object id."""
        return self._id

    @property
    def allowed_states(self) -> List[str]:
        """Get list of allowed object states."""
        return self._allowed_states

    @property
    def allowed_state_transitions(self) -> dict:
        """Get dictionary of allowed state transitions."""
        return self._allowed_transitions

    @property
    def allowed_target_states(self) -> dict:
        """Get dictionary of allowed target states / commands."""
        return self._allowed_target_states

    @property
    def current_state(self) -> str:
        """Get the current state (read from the database)."""
        return DB.get_hash_value(self._key, 'current_state')

    @current_state.setter
    def current_state(self, value):
        """Set the current state, enforcing the allowed transitions."""
        self.update_current_state(value)

    def is_target_state_allowed(self, value: str) -> bool:
        """Test if the target state can currently be set to ``value``.

        Args:
            value (str): Candidate target state (case-insensitive).

        Returns:
            bool, True if ``value`` may be commanded from the current state.
            False if it may not, including when the current state has no
            entry in the allowed target states.

        """
        allowed = self._allowed_target_states.get(self.current_state, [])
        return value.lower() in allowed

    @property
    def target_state(self) -> str:
        """Get the target state (read from the database)."""
        return DB.get_hash_value(self._key, 'target_state')

    @target_state.setter
    def target_state(self, value):
        """Set the target state (uses the default ``force`` behaviour)."""
        self.update_target_state(value)

    @property
    def current_timestamp(self) -> datetime:
        """Get the current state timestamp."""
        timestamp = DB.get_hash_value(self._key, 'current_timestamp')
        return datetime_from_isoformat(timestamp)

    @property
    def target_timestamp(self) -> datetime:
        """Get the target state timestamp."""
        timestamp = DB.get_hash_value(self._key, 'target_timestamp')
        return datetime_from_isoformat(timestamp)

    def update_target_state(self, value: str, force: bool = True) -> datetime:
        """Set the target state.

        Note that ``force`` defaults to True, so the allowed target state
        checks are only applied when called with ``force=False``.

        Args:
            value (str): New value for target state
            force (bool): If true, ignore allowed transitions

        Returns:
            datetime, update timestamp

        Raises:
            RuntimeError, if it is not possible to currently set the target
                state (the current state is 'unknown').
            ValueError, if the specified target state is not allowed.

        """
        value = value.lower()
        if not force:
            current_state = self.current_state
            if current_state == 'unknown':
                raise RuntimeError("Unable to set target state when current "
                                   "state is 'unknown'")
            # A current state without a configured entry allows no targets;
            # report that as ValueError rather than a bare KeyError.
            allowed_target_states = self._allowed_target_states.get(
                current_state, [])
            if value not in allowed_target_states:
                raise ValueError("Invalid target state: '{}'. {} can be "
                                 "commanded to states: {}".
                                 format(value, current_state,
                                        allowed_target_states))
        LOG.debug('Updating target state of %s to %s', self._id, value)
        return self._update_state('target', value)

    def update_current_state(self, value: str,
                             force: bool = False) -> datetime:
        """Update the current state.

        Unless forced, a current state of 'unknown' may be set to any allowed
        state; otherwise only the configured transitions (or keeping the same
        state) are permitted.

        Args:
            value (str): New value for the current state
            force (bool): If true, ignore allowed transitions

        Returns:
            datetime, update timestamp

        Raises:
            ValueError: If the specified current state is not allowed.

        """
        value = value.lower()
        if not force:
            current_state = self.current_state
            if current_state == 'unknown':
                allowed_transitions = self._allowed_states
            else:
                # Build a new list: appending to the stored list would
                # permanently grow self._allowed_transitions on every call.
                allowed_transitions = (
                    self._allowed_transitions.get(current_state, []) +
                    [current_state])
            if value not in allowed_transitions:
                raise ValueError("Invalid current state update: '{}'. '{}' "
                                 "can be transitioned to states: {}"
                                 .format(value, current_state,
                                         allowed_transitions))
        LOG.debug('Updating current state of %s to %s', self._id, value)
        return self._update_state('current', value)

    ###########################################################################
    # Pub/Sub functions
    ###########################################################################

    @staticmethod
    def subscribe(subscriber: str) -> EventQueue:
        """Subscribe to state events.

        Args:
            subscriber (str): Subscriber name.

        Returns:
            events.EventQueue, Event queue object for querying events.

        """
        return subscribe(STATES_KEY, subscriber)

    @staticmethod
    def get_subscribers() -> List[str]:
        """Get the list of subscribers to state events.

        Returns:
            List[str], list of subscriber names.

        """
        return get_subscribers(STATES_KEY)

    def publish(self, event_type: str, event_data: dict = None):
        """Publish a state event.

        The event origin is recorded as ``'<file>::<function>::L<line>'`` of
        the frame three levels up the call stack (the outermost frame if the
        stack is shallower than that).

        Args:
            event_type (str): Type of event.
            event_data (dict, optional): Event data.

        """
        # context=0 avoids reading source lines that are never used.
        stack = inspect.stack(context=0)
        try:
            filename, lineno, function = stack[min(3, len(stack) - 1)][1:4]
        finally:
            # Frame records reference live frames (including this one);
            # drop them explicitly to avoid a reference cycle.
            del stack
        origin = '{}::{}::L{}'.format(os.path.basename(filename), function,
                                      lineno)
        publish(event_type=event_type,
                event_data=event_data,
                object_type=self._type,
                object_id=self._id,
                object_key=self._key,
                origin=origin)

    def get_event_queue(self, subscriber: str) -> EventQueue:
        """Get an event queue for the specified subscriber."""
        return EventQueue(self._type, subscriber)

    ###########################################################################
    # Private functions
    ###########################################################################

    def _initialise(self, initial_state: str = 'unknown') -> dict:
        """Return a dictionary used to initialise a state object.

        This method is used to obtain a dictionary/hash describing the initial
        state of SDP or a service in SDP.

        Args:
            initial_state (str): Initial state.

        Returns:
            dict, Initial state configuration

        Raises:
            ValueError: If the initial state is neither 'unknown' nor one of
                the allowed states.

        """
        initial_state = initial_state.lower()
        if initial_state != 'unknown' and \
                initial_state not in self._allowed_states:
            raise ValueError('Invalid initial state: {}'.format(initial_state))
        # Use a single timestamp so current and target start out identical.
        timestamp = datetime.utcnow().isoformat()
        return dict(
            current_state=initial_state,
            target_state=initial_state,
            current_timestamp=timestamp,
            target_timestamp=timestamp)

    def _update_state(self, state_type: str, value: str) -> datetime:
        """Update the state of type specified (current or target).

        The state value and its timestamp are written in one pipelined
        database call, then a '<state_type>_state_updated' event carrying the
        new and old values is published.

        Args:
            state_type(str): Type of state to update, current or target.
            value (str): New state value.

        Returns:
            timestamp, current time

        """
        timestamp = datetime.utcnow()
        field = '{}_state'.format(state_type)
        old_state = DB.get_hash_value(self._key, field)
        DB.set_hash_value(self._key, field, value, pipeline=True)
        DB.set_hash_value(self._key, '{}_timestamp'.format(state_type),
                          timestamp.isoformat(), pipeline=True)
        DB.execute()
        # Publish an event to notify subscribers of the change in state
        self.publish('{}_state_updated'.format(state_type),
                     event_data=dict(state=value, old_state=old_state))
        return timestamp

    @staticmethod
    def _dict_lower(dictionary: dict) -> dict:
        """Convert allowed state transitions / target states to lowercase.

        Both the keys and every entry of each value list are lowercased.
        """
        return {key.lower(): [item.lower() for item in values]
                for key, values in dictionary.items()}

    def __repr__(self):
        """Get a string representation of a State object."""
        return self._id