Change-Id: I645eea079d9098c9bd1ad76bd508f2e9fc3c8c11 Signed-off-by: Takashi Kajinami <kajinamit@oss.nttdata.com>
489 lines
12 KiB
Python
489 lines
12 KiB
Python
# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import contextlib
|
|
import os
|
|
import string
|
|
import threading
|
|
import time
|
|
|
|
import etcd3gw
|
|
from oslo_utils import timeutils
|
|
import redis
|
|
|
|
from taskflow import exceptions
|
|
from taskflow.listeners import capturing
|
|
from taskflow.persistence.backends import impl_memory
|
|
from taskflow import retry
|
|
from taskflow import task
|
|
from taskflow.types import failure
|
|
from taskflow.utils import kazoo_utils
|
|
from taskflow.utils import redis_utils
|
|
|
|
ETCD_PORT = int(os.getenv("TASKFLOW_TEST_ETCD_PORT", 2379))
|
|
REDIS_PORT = int(os.getenv("TASKFLOW_TEST_REDIS_PORT", 6379))
|
|
ZK_PORT = int(os.getenv("TASKFLOW_TEST_ZOOKEEPER_PORT", 2181))
|
|
ZK_TEST_CONFIG = {
|
|
'timeout': 1.0,
|
|
'hosts': ["localhost:%d" % ZK_PORT],
|
|
}
|
|
# If latches/events take longer than this to become empty/set, something is
|
|
# usually wrong and should be debugged instead of deadlocking...
|
|
WAIT_TIMEOUT = 300
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def wrap_all_failures():
|
|
"""Convert any exceptions to WrappedFailure.
|
|
|
|
When you expect several failures, it may be convenient
|
|
to wrap any exception with WrappedFailure in order to
|
|
unify error handling.
|
|
"""
|
|
try:
|
|
yield
|
|
except Exception:
|
|
raise exceptions.WrappedFailure([failure.Failure()])
|
|
|
|
|
|
def zookeeper_available(min_version, timeout=30):
|
|
url = os.getenv("TAKSFLOW_TEST_URL")
|
|
if url is not None:
|
|
return url.startswith("zookeeper://")
|
|
|
|
client = kazoo_utils.make_client(ZK_TEST_CONFIG.copy())
|
|
try:
|
|
# NOTE(imelnikov): 30 seconds we should be enough for localhost
|
|
client.start(timeout=float(timeout))
|
|
if min_version:
|
|
zk_ver = client.server_version()
|
|
if zk_ver >= min_version:
|
|
return True
|
|
else:
|
|
return False
|
|
else:
|
|
return True
|
|
except Exception:
|
|
return False
|
|
finally:
|
|
kazoo_utils.finalize_client(client)
|
|
|
|
|
|
def redis_available(min_version):
|
|
url = os.getenv("TAKSFLOW_TEST_URL")
|
|
if url is not None:
|
|
return url.startswith("redis://")
|
|
|
|
client = redis.Redis(port=REDIS_PORT)
|
|
try:
|
|
client.ping()
|
|
except Exception:
|
|
return False
|
|
else:
|
|
ok, redis_version = redis_utils.is_server_new_enough(client,
|
|
min_version)
|
|
return ok
|
|
|
|
|
|
def etcd_available():
|
|
url = os.getenv("TAKSFLOW_TEST_URL")
|
|
if url is not None:
|
|
return url.startswith("etcd://")
|
|
|
|
client = etcd3gw.Etcd3Client(port=ETCD_PORT)
|
|
try:
|
|
client.get("/")
|
|
except Exception:
|
|
return False
|
|
return True
|
|
|
|
|
|
class NoopRetry(retry.AlwaysRevert):
|
|
pass
|
|
|
|
|
|
class NoopTask(task.Task):
|
|
|
|
def execute(self):
|
|
pass
|
|
|
|
|
|
class DummyTask(task.Task):
|
|
|
|
def execute(self, context, *args, **kwargs):
|
|
pass
|
|
|
|
|
|
class EmittingTask(task.Task):
|
|
TASK_EVENTS = (task.EVENT_UPDATE_PROGRESS, 'hi')
|
|
|
|
def execute(self, *args, **kwargs):
|
|
self.notifier.notify('hi',
|
|
details={'sent_on': timeutils.utcnow(),
|
|
'args': args, 'kwargs': kwargs})
|
|
|
|
|
|
class AddOneSameProvidesRequires(task.Task):
|
|
default_provides = 'value'
|
|
|
|
def execute(self, value):
|
|
return value + 1
|
|
|
|
|
|
class AddOne(task.Task):
|
|
default_provides = 'result'
|
|
|
|
def execute(self, source):
|
|
return source + 1
|
|
|
|
|
|
class GiveBackRevert(task.Task):
|
|
|
|
def execute(self, value):
|
|
return value + 1
|
|
|
|
def revert(self, *args, **kwargs):
|
|
result = kwargs.get('result')
|
|
# If this somehow fails, timeout, or other don't send back a
|
|
# valid result...
|
|
if isinstance(result, int):
|
|
return result + 1
|
|
|
|
|
|
class FakeTask:
|
|
|
|
def execute(self, **kwargs):
|
|
pass
|
|
|
|
|
|
class LongArgNameTask(task.Task):
|
|
|
|
def execute(self, long_arg_name):
|
|
return long_arg_name
|
|
|
|
|
|
RUNTIME_ERROR_CLASSES = ['RuntimeError', 'Exception', 'BaseException',
|
|
'object']
|
|
|
|
|
|
class ProvidesRequiresTask(task.Task):
|
|
def __init__(self, name, provides, requires, return_tuple=True):
|
|
super().__init__(name=name, provides=provides, requires=requires)
|
|
self.return_tuple = isinstance(provides, (tuple, list))
|
|
|
|
def execute(self, *args, **kwargs):
|
|
if self.return_tuple:
|
|
return tuple(range(len(self.provides)))
|
|
else:
|
|
return {k: k for k in self.provides}
|
|
|
|
|
|
# Used to format the captured values into strings (which are easier to
|
|
# check later in tests)...
|
|
LOOKUP_NAME_POSTFIX = {
|
|
capturing.CaptureListener.TASK: ('.t', 'task_name'),
|
|
capturing.CaptureListener.RETRY: ('.r', 'retry_name'),
|
|
capturing.CaptureListener.FLOW: ('.f', 'flow_name'),
|
|
}
|
|
|
|
|
|
class CaptureListener(capturing.CaptureListener):
|
|
|
|
@staticmethod
|
|
def _format_capture(kind, state, details):
|
|
name_postfix, name_key = LOOKUP_NAME_POSTFIX[kind]
|
|
name = details[name_key] + name_postfix
|
|
if 'result' in details:
|
|
name += ' {}({})'.format(state, details['result'])
|
|
else:
|
|
name += " %s" % state
|
|
return name
|
|
|
|
|
|
class MultiProgressingTask(task.Task):
|
|
def execute(self, progress_chunks):
|
|
for chunk in progress_chunks:
|
|
self.update_progress(chunk)
|
|
return len(progress_chunks)
|
|
|
|
|
|
class ProgressingTask(task.Task):
|
|
def execute(self, **kwargs):
|
|
self.update_progress(0.0)
|
|
self.update_progress(1.0)
|
|
return 5
|
|
|
|
def revert(self, **kwargs):
|
|
self.update_progress(0)
|
|
self.update_progress(1.0)
|
|
|
|
|
|
class FailingTask(ProgressingTask):
|
|
def execute(self, **kwargs):
|
|
self.update_progress(0)
|
|
self.update_progress(0.99)
|
|
raise RuntimeError('Woot!')
|
|
|
|
|
|
class SimpleTask(task.Task):
|
|
def execute(self, time_sleep=0, **kwargs):
|
|
time.sleep(time_sleep)
|
|
|
|
|
|
class SuccessAfter3Sec(task.Task):
|
|
def execute(self, start_time, **kwargs):
|
|
now = time.time()
|
|
if now - start_time >= 3:
|
|
return None
|
|
raise RuntimeError('Woot!')
|
|
|
|
|
|
class RetryFiveTimes(retry.Times):
|
|
def on_failure(self, history, *args, **kwargs):
|
|
if len(history) < 5:
|
|
time.sleep(1)
|
|
return retry.RETRY
|
|
return retry.REVERT_ALL
|
|
|
|
|
|
class AlwaysRetry(retry.Times):
|
|
def on_failure(self, history, *args, **kwargs):
|
|
return retry.RETRY
|
|
|
|
|
|
class OptionalTask(task.Task):
|
|
def execute(self, a, b=5):
|
|
result = a * b
|
|
return result
|
|
|
|
|
|
class TaskWithFailure(task.Task):
|
|
|
|
def execute(self, **kwargs):
|
|
raise RuntimeError('Woot!')
|
|
|
|
|
|
class FailingTaskWithOneArg(ProgressingTask):
|
|
def execute(self, x, **kwargs):
|
|
raise RuntimeError('Woot with %s' % x)
|
|
|
|
|
|
class NastyTask(task.Task):
|
|
|
|
def execute(self, **kwargs):
|
|
pass
|
|
|
|
def revert(self, **kwargs):
|
|
raise RuntimeError('Gotcha!')
|
|
|
|
|
|
class NastyFailingTask(NastyTask):
|
|
def execute(self, **kwargs):
|
|
raise RuntimeError('Woot!')
|
|
|
|
|
|
class TaskNoRequiresNoReturns(task.Task):
|
|
|
|
def execute(self, **kwargs):
|
|
pass
|
|
|
|
def revert(self, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskOneArg(task.Task):
|
|
|
|
def execute(self, x, **kwargs):
|
|
pass
|
|
|
|
def revert(self, x, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskMultiArg(task.Task):
|
|
|
|
def execute(self, x, y, z, **kwargs):
|
|
pass
|
|
|
|
def revert(self, x, y, z, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskOneReturn(task.Task):
|
|
|
|
def execute(self, **kwargs):
|
|
return 1
|
|
|
|
def revert(self, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskMultiReturn(task.Task):
|
|
|
|
def execute(self, **kwargs):
|
|
return 1, 3, 5
|
|
|
|
def revert(self, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskOneArgOneReturn(task.Task):
|
|
|
|
def execute(self, x, **kwargs):
|
|
return 1
|
|
|
|
def revert(self, x, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskMultiArgOneReturn(task.Task):
|
|
|
|
def execute(self, x, y, z, **kwargs):
|
|
return x + y + z
|
|
|
|
def revert(self, x, y, z, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskMultiArgMultiReturn(task.Task):
|
|
|
|
def execute(self, x, y, z, **kwargs):
|
|
return 1, 3, 5
|
|
|
|
def revert(self, x, y, z, **kwargs):
|
|
pass
|
|
|
|
|
|
class TaskMultiDict(task.Task):
|
|
|
|
def execute(self):
|
|
output = {}
|
|
for i, k in enumerate(sorted(self.provides)):
|
|
output[k] = i
|
|
return output
|
|
|
|
|
|
class NeverRunningTask(task.Task):
|
|
def execute(self, **kwargs):
|
|
assert False, 'This method should not be called'
|
|
|
|
def revert(self, **kwargs):
|
|
assert False, 'This method should not be called'
|
|
|
|
|
|
class TaskRevertExtraArgs(task.Task):
|
|
def execute(self, **kwargs):
|
|
raise exceptions.ExecutionFailure("We want to force a revert here")
|
|
|
|
def revert(self, revert_arg, flow_failures, result, **kwargs):
|
|
pass
|
|
|
|
|
|
class SleepTask(task.Task):
|
|
def execute(self, duration, **kwargs):
|
|
time.sleep(duration)
|
|
|
|
|
|
class EngineTestBase:
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.backend = impl_memory.MemoryBackend(conf={})
|
|
|
|
def tearDown(self):
|
|
EngineTestBase.values = None
|
|
with contextlib.closing(self.backend) as be:
|
|
with contextlib.closing(be.get_connection()) as conn:
|
|
conn.clear_all()
|
|
super().tearDown()
|
|
|
|
def _make_engine(self, flow, **kwargs):
|
|
raise exceptions.NotImplementedError("_make_engine() must be"
|
|
" overridden if an engine is"
|
|
" desired")
|
|
|
|
|
|
class FailureMatcher:
|
|
"""Needed for failure objects comparison."""
|
|
|
|
def __init__(self, failure):
|
|
self._failure = failure
|
|
|
|
def __repr__(self):
|
|
return str(self._failure)
|
|
|
|
def __eq__(self, other):
|
|
return self._failure.matches(other)
|
|
|
|
def __ne__(self, other):
|
|
return not self.__eq__(other)
|
|
|
|
|
|
class OneReturnRetry(retry.AlwaysRevert):
|
|
|
|
def execute(self, **kwargs):
|
|
return 1
|
|
|
|
def revert(self, **kwargs):
|
|
pass
|
|
|
|
|
|
class ConditionalTask(ProgressingTask):
|
|
|
|
def execute(self, x, y):
|
|
super().execute()
|
|
if x != y:
|
|
raise RuntimeError('Woot!')
|
|
|
|
|
|
class WaitForOneFromTask(ProgressingTask):
|
|
|
|
def __init__(self, name, wait_for, wait_states, **kwargs):
|
|
super().__init__(name, **kwargs)
|
|
if isinstance(wait_for, str):
|
|
self.wait_for = [wait_for]
|
|
else:
|
|
self.wait_for = wait_for
|
|
if isinstance(wait_states, str):
|
|
self.wait_states = [wait_states]
|
|
else:
|
|
self.wait_states = wait_states
|
|
self.event = threading.Event()
|
|
|
|
def execute(self):
|
|
if not self.event.wait(WAIT_TIMEOUT):
|
|
raise RuntimeError('%s second timeout occurred while waiting '
|
|
'for %s to change state to %s'
|
|
% (WAIT_TIMEOUT, self.wait_for,
|
|
self.wait_states))
|
|
return super().execute()
|
|
|
|
def callback(self, state, details):
|
|
name = details.get('task_name', None)
|
|
if name not in self.wait_for or state not in self.wait_states:
|
|
return
|
|
self.event.set()
|
|
|
|
|
|
def make_many(amount, task_cls=DummyTask, offset=0):
|
|
name_pool = string.ascii_lowercase + string.ascii_uppercase
|
|
tasks = []
|
|
while amount > 0:
|
|
if offset >= len(name_pool):
|
|
raise AssertionError('Name pool size to small (%s < %s)'
|
|
% (len(name_pool), offset + 1))
|
|
tasks.append(task_cls(name=name_pool[offset]))
|
|
offset += 1
|
|
amount -= 1
|
|
return tasks
|