#
# Author: Guillermo Gonzalez <guillermo.gonzalez@canonical.com>
#
# Copyright 2009 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Base test cases and test utilities."""
from __future__ import with_statement

import contextlib
import logging
import os
import shutil
import itertools

from ubuntuone.syncdaemon import (
    config,
    action_queue,
    event_queue,
    filesystem_manager as fs_manager,
    interfaces,
    volume_manager,
    main,
    local_rescan,
    tritcask,
)
from ubuntuone.syncdaemon import logger
logger.init()
from twisted.internet import defer
from twisted.trial.unittest import TestCase as TwistedTestCase
from zope.interface import implements
from zope.interface.verify import verifyObject


# Fixed set of fake OAuth-style credentials; FakeExternalInterface's
# _request_token hands this dict back to callers.
FAKED_CREDENTIALS = dict(
    consumer_key='faked_consumer_key',
    consumer_secret='faked_consumer_secret',
    token='faked_token',
    token_secret='faked_token_secret',
    token_name='Test me please',
)


@contextlib.contextmanager
def environ(env_var, new_value):
    """Context manager to temporarily replace/add an environ value.

    On exit -- including exit caused by an exception raised in the body --
    the previous value is restored, or the variable is removed again if it
    was not set before entering.
    """
    old_value = os.environ.get(env_var, None)
    os.environ[env_var] = new_value
    try:
        yield
    finally:
        if old_value is None:
            # the variable did not exist before; tolerate the body having
            # already removed it
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = old_value


class FakeHashQueue(object):
    """Hash queue stand-in that never holds any work.

    Instead of hashing, insert() immediately pushes an HQ_HASH_NEW event
    with placeholder hash data to the event queue it was built with.
    """

    def __init__(self, eq):
        # event queue that receives the fake HQ_HASH_NEW events
        self.eq = eq

    def empty(self):
        """Report the queue as empty (always True)."""
        return True

    def shutdown(self):
        """Do nothing: there is nothing to stop."""

    def __len__(self):
        """Report zero pending items."""
        return 0

    def insert(self, path, mdid):
        """Push HQ_HASH_NEW for *path* with an empty hash and crc32,
        size 0 and the current os.stat() of *path*; *mdid* is ignored."""
        file_stat = os.stat(path)
        self.eq.push('HQ_HASH_NEW', path=path, hash='', crc32='', size=0,
                     stat=file_stat)


class FakeMark(object):
    """Minimal stand-in for main's ``mark`` component: it can only stop."""

    def stop(self):
        """Ignore the stop request."""
        return None


class FakeExternalInterface(object):
    """Stand-in for the external (DBus) interface of the daemon."""

    def shutdown(self, with_restart=False):
        """Ignore the shutdown request; *with_restart* is unused."""
        return None

    def _request_token(self, *args, **kwargs):
        """Ignore every argument and hand back FAKED_CREDENTIALS."""
        credentials = FAKED_CREDENTIALS
        return credentials


class FakeActionQueue(object):
    """Stub implementation of IActionQueue for tests.

    Most operations are no-op aliases of ``disconnect`` (they accept any
    arguments and do nothing); connect() and answer_share() push the event
    a real action queue would eventually produce, and the throttling
    methods only toggle ``throttling_enabled``.
    """

    implements(interfaces.IActionQueue)

    def __init__(self, eq, *args, **kwargs):
        """Create the instance.

        eq is the event queue, also exposed as ``event_queue``; any other
        positional or keyword arguments are accepted and ignored.
        """
        self.eq = self.event_queue = eq
        self.uuid_map = action_queue.DeferredMap()
        self.queue = action_queue.RequestQueue(self)

        # throttling attributes
        self.readLimit = None
        self.writeLimit = None
        self.throttling_enabled = False

    def __setattr__(self, attr, value):
        """Custom __setattr__ that checks the interface.

        After setting a callable attribute (e.g. a test replacing one of
        the stubs), run verifyObject so a replacement that breaks
        IActionQueue is reported at assignment time.
        """
        r = super(FakeActionQueue, self).__setattr__(attr, value)
        if callable(value):
            # check that AQ implements IActionQueue.
            verifyObject(interfaces.IActionQueue, self)
        return r

    # IMetaQueue
    def connect(self, host=None, port=None, user_ssl=False):
        """Pretend to connect: just push SYS_CONNECTION_MADE."""
        self.eq.push('SYS_CONNECTION_MADE')

    def enable_throttling(self):
        """We have throttling enabled now."""
        self.throttling_enabled = True

    def disable_throttling(self):
        """We have throttling disabled now."""
        self.throttling_enabled = False

    def answer_share(self, share_id, answer):
        """Push AQ_ANSWER_SHARE_OK for the given share and answer."""
        self.eq.push('AQ_ANSWER_SHARE_OK', share_id=share_id, answer=answer)

    def disconnect(self, *a, **k):
        """Stub implementation: accept any arguments, do nothing."""

    # All the remaining stubbed operations -- including the IContentQueue
    # ones (cancel_download, cancel_upload, download, upload) -- are the
    # same accept-anything no-op as disconnect.
    cancel_download = cancel_upload = download = upload = make_dir = disconnect
    make_file = move = unlink = list_shares = disconnect
    list_volumes = create_share = create_udf = inquire_free_space = disconnect
    inquire_account_info = delete_volume = change_public_access = disconnect
    query_volumes = get_delta = rescan_from_scratch = delete_share = disconnect
    node_is_with_queued_move = cleanup = get_public_files = disconnect


class FakeStatusListener(object):
    """Status listener stand-in that only exposes its notification flag."""

    # class-level flag, shared by every instance
    show_all_notifications = True


class FakeMain(main.Main):
    """A fake Main class to set up the tests.

    It deliberately skips main.Main.__init__ and wires together real
    VolumeManager, FileSystemManager, EventQueue, StateManager and
    LocalRescan instances with fakes for the action queue, hash queue,
    mark, external interface and status listener. The connection and
    rescan steps overridden below just push the matching success events.
    """

    # Class used to build self.action_q; it is called as
    # _fake_AQ_class(event_q, main, *_fake_AQ_params).
    _fake_AQ_class = FakeActionQueue
    # extra positional arguments appended when building the action queue
    _fake_AQ_params = ()
    # when not None, instantiated with this main and stored as self.sync
    _sync_class = None

    # don't call Main.__init__ we take care of creating a fake main and
    # all its attributes. pylint: disable-msg=W0231
    def __init__(self, root_dir, shares_dir, data_dir, partials_dir):
        """Create the instance and all of its components.

        Statement order matters: later components are built from earlier
        ones (fs needs vm and db, event_q needs fs, the action queue and
        LocalRescan need event_q).
        """
        self.logger = logging.getLogger('ubuntuone.SyncDaemon.FakeMain')
        self.root_dir = root_dir
        self.data_dir = data_dir
        self.shares_dir = shares_dir
        self.partials_dir = partials_dir
        self.shares_dir_link = os.path.join(self.root_dir, 'Shared With Me')
        self.db = tritcask.Tritcask(os.path.join(self.data_dir, 'tritcask'))
        self.vm = volume_manager.VolumeManager(self)
        self.fs = fs_manager.FileSystemManager(
            self.data_dir, self.partials_dir, self.vm, self.db)
        self.event_q = event_queue.EventQueue(self.fs)
        # event_q was built from fs; now link fs back to event_q
        self.fs.register_eq(self.event_q)
        self.action_q = self._fake_AQ_class(self.event_q, self,
                                            *self._fake_AQ_params)
        # NOTE(review): meaning of the second argument (2) is not visible
        # here -- presumably a timeout; confirm against main.StateManager
        self.state_manager = main.StateManager(self, 2)
        if self._sync_class is not None:
            self.sync = self._sync_class(self)
        # the volume manager receives events pushed on event_q
        self.event_q.subscribe(self.vm)
        self.vm.init_root()
        self.hash_q = FakeHashQueue(self.event_q)
        self.mark = FakeMark()
        self.external = FakeExternalInterface()
        self.lr = local_rescan.LocalRescan(self.vm, self.fs,
                                           self.event_q, self.action_q)
        self.status_listener = FakeStatusListener()

    def _connect_aq(self, _):
        """Connect the fake action queue (the argument is ignored)."""
        self.action_q.connect()

    def _disconnect_aq(self):
        """Disconnect the fake action queue."""
        self.action_q.disconnect()

    def check_version(self):
        """Pretend the protocol version check passed.

        Pushes SYS_PROTOCOL_VERSION_OK without talking to any server.
        """
        self.event_q.push('SYS_PROTOCOL_VERSION_OK')

    def authenticate(self):
        """Pretend authentication succeeded: push SYS_AUTH_OK."""
        self.event_q.push('SYS_AUTH_OK')

    def set_capabilities(self):
        """Pretend capabilities were set: push SYS_SET_CAPABILITIES_OK."""
        self.event_q.push('SYS_SET_CAPABILITIES_OK')

    def get_root(self, root_mdid):
        """Return a Deferred that fires with the fixed uuid 'root_uuid'.

        root_mdid is ignored; no action queue request is made.
        """
        return defer.succeed('root_uuid')

    def server_rescan(self):
        """Skip the server rescan.

        Pushes SYS_SERVER_RESCAN_DONE and returns a Deferred that fires
        with 'root_uuid'.
        """
        self.event_q.push('SYS_SERVER_RESCAN_DONE')
        return defer.succeed('root_uuid')

    def local_rescan(self):
        """Skip the local rescan.

        Pushes SYS_LOCAL_RESCAN_DONE and returns a Deferred that fires
        with True.
        """
        self.event_q.push('SYS_LOCAL_RESCAN_DONE')
        return defer.succeed(True)


class BaseTwistedTestCase(TwistedTestCase):
    """Base TestCase with helper methods to handle temp dir.

    This class provides:
        mktemp(name): helper to create temporary dirs
        rmtree(path): support read-only shares
        makedirs(path): support read-only shares
    """

    def mktemp(self, name='temp'):
        """ Customized mktemp that accepts an optional name argument. """
        tempdir = os.path.join(self.tmpdir, name)
        if os.path.exists(tempdir):
            self.rmtree(tempdir)
        self.makedirs(tempdir)
        return tempdir

    @property
    def tmpdir(self):
        """Default tmpdir: module/class/test_method."""
        # check if we already generated the root path
        root_dir = getattr(self, '__root', None)
        if root_dir:
            return root_dir
        MAX_FILENAME = 32 # some platforms limit lengths of filenames
        base = os.path.join(self.__class__.__module__[:MAX_FILENAME],
                            self.__class__.__name__[:MAX_FILENAME],
                            self._testMethodName[:MAX_FILENAME])
        # use _trial_temp dir, it should be os.gwtcwd()
        # define the root temp dir of the testcase, pylint: disable-msg=W0201
        self.__root = os.path.join(os.getcwd(), base)
        return self.__root

    def rmtree(self, path):
        """Custom rmtree that handle ro parent(s) and childs."""
        if not os.path.exists(path):
            return
        # change perms to rw, so we can delete the temp dir
        if path != getattr(self, '__root', None):
            os.chmod(os.path.dirname(path), 0755)
        if not os.access(path, os.W_OK):
            os.chmod(path, 0755)
        # pylint: disable-msg=W0612
        for dirpath, dirs, files in os.walk(path):
            for dir in dirs:
                if not os.access(os.path.join(dirpath, dir), os.W_OK):
                    os.chmod(os.path.join(dirpath, dir), 0777)
        shutil.rmtree(path)

    def makedirs(self, path):
        """Custom makedirs that handle ro parent."""
        parent = os.path.dirname(path)
        if os.path.exists(parent):
            os.chmod(parent, 0755)
        os.makedirs(path)

    def setUp(self):
        TwistedTestCase.setUp(self)
        # use the config from the branch
        self.old_get_config_files = config.get_config_files
        config.get_config_files = lambda: [os.path.join(os.environ['ROOTDIR'],
                                                   'data', 'syncdaemon.conf')]
        # fake a very basic config file with sane defaults for the tests
        self.config_file = os.path.join(self.mktemp('config'), 'syncdaemon.conf')
        with open(self.config_file, 'w') as fp:
            fp.write('[bandwidth_throttling]\n')
            fp.write('on = False\n')
            fp.write('read_limit = -1\n')
            fp.write('write_limit = -1\n')
        # invalidate the current config
        config._user_config = None
        config.get_user_config(config_file=self.config_file)

    def tearDown(self):
        """ cleanup the temp dir. """
        # invalidate the current config
        config._user_config = None
        # restore the old get_config_files
        config.get_config_files = self.old_get_config_files
        self.rmtree(os.path.dirname(self.config_file))
        root_dir = getattr(self, '__root', None)
        if root_dir:
            self.rmtree(self.__root)
        return TwistedTestCase.tearDown(self)


class FakeMainTestCase(BaseTwistedTestCase):
    """A testcase that starts up a Main instance."""

    def setUp(self):
        """Setup the infrastructure for the test."""
        BaseTwistedTestCase.setUp(self)
        self.log = logging.getLogger("ubuntuone.SyncDaemon.TEST")
        self.log.info("starting test %s.%s", self.__class__.__name__,
                      self._testMethodName)
        self.timeout = 2
        self.data_dir = self.mktemp('data_dir')
        self.partials_dir = self.mktemp('partials')
        self.root_dir = self.mktemp('root_dir')
        self.shares_dir = self.mktemp('shares_dir')
        self.main = FakeMain(self.root_dir, self.shares_dir,
                             self.data_dir, self.partials_dir)
        self.vm = self.main.vm
        self.fs = self.main.fs
        self.event_q = self.main.event_q
        self.action_q = self.main.action_q
        self.event_q.push('SYS_INIT_DONE')

    def tearDown(self):
        """Shutdown this testcase."""
        BaseTwistedTestCase.tearDown(self)
        self.main.shutdown()
        self.rmtree(self.shares_dir)
        self.rmtree(self.root_dir)
        self.rmtree(self.data_dir)
        self.rmtree(self.partials_dir)
        self.log.info("finished test %s.%s", self.__class__.__name__,
                      self._testMethodName)


class FakeVolumeManager(object):
    """ A volume manager that only knows one share, the root"""

    def __init__(self, root_path):
        """ Creates the instance"""
        self.root = volume_manager.Root(node_id="root_node_id", path=root_path)
        self.shares = {'':self.root}
        self.udfs = {}
        self.log = logging.getLogger('ubuntuone.SyncDaemon.VM-test')

    def add_share(self, share):
        """Add share to the shares dict."""
        self.shares[share.id] = share
        # if the share don't exists, create it
        if not os.path.exists(share.path):
            os.mkdir(share.path)
        # if it's a ro share, change the perms
        if not share.can_write():
            os.chmod(share.path, 0555)

    def add_udf(self, udf):
        """Add udf to the udfs dict."""
        self.udfs[udf.id] = udf

    def share_deleted(self, _):
        """Do nothing."""

    def get_volume(self, id):
        """Returns a share or a UDF."""
        try:
            return self.shares[id]
        except KeyError:
            try:
                return self.udfs[id]
            except KeyError:
                raise volume_manager.VolumeDoesNotExist(id)

    def get_volumes(self, all_volumes=False):
        """Simple get_volumes for FakeVolumeManager."""
        volumes = itertools.chain(self.shares.values(), self.udfs.values())
        for volume in volumes:
            if all_volumes or volume.active:
                yield volume

    def unsubscribe_udf(self, udf_id):
        """Mark the UDF with udf_id as unsubscribed."""
        udf = self.udfs[udf_id]
        udf.subscribed = False
        self.udfs[udf_id] = udf

    def delete_volume(self, volume_id):
        """Request the deletion of a volume."""


class FakeLogger(object):
    """Logger stand-in that records formatted messages per level.

    Messages are stored in self.logged['debug'], self.logged['info'] and
    self.logged['warning'], in call order.
    """

    def __init__(self):
        self.logged = {'debug': [], 'warning': [], 'info': []}

    def _log(self, log, txt, args):
        """Append txt to the list log, %-formatted with args only when
        any args were given."""
        log.append(txt % args if args else txt)

    def debug(self, txt, *args):
        """Record a DEBUG message."""
        self._log(self.logged['debug'], txt, args)

    def info(self, txt, *args):
        """Record an INFO message."""
        self._log(self.logged['info'], txt, args)

    def warning(self, txt, *args):
        """Record a WARNING message."""
        self._log(self.logged['warning'], txt, args)


class Listener(object):
    """Event collector that records every event it receives."""

    def __init__(self):
        # (event_name, kwargs) pairs, in arrival order
        self.events = []

    def handle_default(self, event_name, **kwargs):
        """Store the event name together with its keyword arguments."""
        record = (event_name, kwargs)
        self.events.append(record)


class DummyClass(object):
    """Object on which every missing attribute is a do-nothing callable."""

    def __getattr__(self, name):
        """Return a callable that accepts anything and returns None."""
        def _noop(*args, **kwargs):
            return None
        return _noop
