# -*- coding: utf-8 -*-
#
# Author: John R. Lenton <john.lenton@canonical.com>
# Author: Natalia B. Bidart <natalia.bidart@canonical.com>
#
# Copyright 2009 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.
"""Tests for the action queue module."""

from __future__ import with_statement

import base64
import logging
import os
import shutil
import unittest
import urllib2
import uuid

import dbus
from dbus.mainloop.glib import DBusGMainLoop
from functools import wraps
from StringIO import StringIO
from twisted.internet import defer, threads, reactor
from twisted.internet import error as twisted_error
from twisted.python.failure import DefaultException, Failure
from twisted.web import server

from contrib.testing.testcase import (
    BaseTwistedTestCase, MementoHandler, DummyClass
)

from ubuntuone.storageprotocol import client, errors, protocol_pb2, volumes
from ubuntuone.syncdaemon import states
from ubuntuone.syncdaemon.dbus_interface import DBusInterface
from ubuntuone.syncdaemon.main import Main
from ubuntuone.syncdaemon.action_queue import (
    ActionQueue, ActionQueueCommand, ChangePublicAccess, CreateUDF,
    DeleteVolume, ListDir, ListVolumes, NoisyRequestQueue, RequestQueue,
    Upload, CreateShare, GetPublicFiles,
)
from ubuntuone.syncdaemon.event_queue import EventQueue, EVENTS
from ubuntuone.syncdaemon.volume_manager import UDF


# Put the D-Bus interface in test mode at import time, before any test
# below builds a Main (and, through it, a DBusInterface).
# NOTE(review): what DBusInterface.test switches off is defined in
# dbus_interface, not here -- presumably real bus side effects; confirm.
DBusInterface.test = True

# Shared fixture values for the volume / UDF command tests below.
# PATH includes a non-ASCII character (n with tilde) in a unicode string.
PATH = u'~/Documents/pdfs/moño/'
NAME = u'UDF-me'
VOLUME = uuid.UUID('12345678-1234-1234-1234-123456789abc')
NODE = uuid.UUID('FEDCBA98-7654-3211-2345-6789ABCDEF12')
USER = u'Dude'

def fire_and_check(f, deferred, check):
    """Wrap f so that each call also fires deferred.

    After f returns, check() is called with no arguments. A falsy result
    fires deferred.callback(True); a truthy one is handed to
    deferred.errback. The wrapper returns whatever f returned. If f
    raises, check() is not called and the deferred is left unfired.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        """Call f, then fire the deferred according to check()."""
        outcome = f(*args, **kwargs)
        problem = check()
        if problem:
            deferred.errback(problem)
        else:
            deferred.callback(True)
        return outcome
    return wrapper

class FakeCommand(object):
    """Minimal stand-in for an action queue command.

    It can always run, and running it finishes immediately.
    """

    def is_runnable(self):
        """Report that this command may run right away."""
        return True

    def run(self):
        """Finish at once with an already-fired deferred (result None)."""
        finished = defer.succeed(None)
        return finished


class FakedEventQueue(EventQueue):
    """EventQueue that also logs every push in ``self.events``.

    Each entry is an ``(event_name, args, kwargs)`` tuple, appended before
    the event is handed to the real queue.
    """

    def __init__(self, fs=None):
        """Build the real queue, then start with an empty event log."""
        parent = super(FakedEventQueue, self)
        parent.__init__(fs=fs)
        self.events = []

    def push(self, event_name, *args, **kwargs):
        """Record the event, then push it through the real queue."""
        record = (event_name, args, kwargs)
        self.events.append(record)
        parent = super(FakedEventQueue, self)
        parent.push(event_name, *args, **kwargs)


class FakedVolume(object):
    """Opaque placeholder passed around wherever a volume is expected."""


class TestingProtocol(ActionQueue.protocol):
    """Protocol for testing.

    After the real connectionMade runs, it checks that exactly
    SYS_CONNECTION_MADE was pushed. On success it clears the recorded
    events and fires ``self.testing_deferred``, which the test's
    buildProtocol hook attaches to the instance. On mismatch it errbacks
    that deferred with an AssertionError.
    """

    def connectionMade(self):
        """Run the real connectionMade, then verify and fire the deferred."""
        ActionQueue.protocol.connectionMade(self)
        # proper event is pushed
        expected = [('SYS_CONNECTION_MADE', (), {})]
        actual = self.factory.event_queue.events
        if expected != actual:
            # Report through the deferred the test is waiting on. A bare
            # assert is stripped under -O, and raising into Twisted here
            # would leave the deferred unfired until the test times out.
            msg = 'events must be %s not %s' % (expected, actual)
            self.testing_deferred.errback(AssertionError(msg))
            return

        self.factory.event_queue.events = []  # reset events
        self.testing_deferred.callback(True)


class BasicTestCase(BaseTwistedTestCase):
    """Basic tests to check ActionQueue."""

    def setUp(self):
        """Init."""
        BaseTwistedTestCase.setUp(self)

        self.root = self.mktemp('root')
        self.home = self.mktemp('home')
        self.data = self.mktemp('data')
        self.shares = self.mktemp('shares')
        self.partials = self.mktemp('partials')

        self.main = Main(root_dir=self.root,
                         shares_dir=self.shares,
                         data_dir=self.data,
                         partials_dir=self.partials,
                         host='127.0.0.1', port=55555,
                         dns_srv=False, ssl=False,
                         disable_ssl_verify=True,
                         realm='fake.realm',
                         mark_interval=60,
                         handshake_timeout=2,
                         glib_loop=DBusGMainLoop(set_as_default=True))

        self.action_queue = self.main.action_q
        self.action_queue.connection_timeout=3
        self.action_queue.event_queue.events = []

        def keep_a_copy(f):
            """Keep a copy of the pushed events."""
            @wraps(f)
            def recording(event_name, *args, **kwargs):
                """Keep a copy of the pushed events."""
                value = (event_name, args, kwargs)
                self.action_queue.event_queue.events.append(value)
                return f(event_name, *args, **kwargs)
            return recording

        self.main.event_q.push = keep_a_copy(self.main.event_q.push)

        self.handler = MementoHandler()
        self.handler.setLevel(logging.INFO)
        logging.getLogger('ubuntuone.SyncDaemon').addHandler(self.handler)

        # fake local rescan call to not be executed
        self.main.local_rescan = lambda: self.main.event_q.push(
                                                    'SYS_LOCAL_RESCAN_DONE')
        dbus.service.BusName.__del__ = lambda _: None

    def tearDown(self):
        """Cleanup."""
        self.main.shutdown()

        shutil.rmtree(self.root)
        shutil.rmtree(self.shares)
        shutil.rmtree(self.data)
        shutil.rmtree(self.partials)

        for record in self.handler.records:
            exc_info = getattr(record, 'exc_info', None)
            if exc_info is not None:
                raise exc_info[0], exc_info[1], exc_info[2]

        BaseTwistedTestCase.tearDown(self)

    def test_creation_requires_main(self):
        """Main instance is needed at creation time."""
        self.assertEquals(self.main, self.action_queue.main)

    def test_content_queue_has_only_one_op_per_node(self):
        """
        Check that the content queue uniquifies operations per node.
        """
        self.main.start()
        # totally fake, we don't care: the messages are only validated on run
        self.main.action_q.download('foo', 'bar', 0, 0)
        self.main.action_q.upload('foo', 'bar', 0, 0, 0, 0, 0)
        self.assertEqual(len(self.main.action_q.content_queue.waiting), 1)

    def test_content_queue_has_only_one_op_per_node_even_counting_markers(self):
        """
        Check that the content queue uniquifies operations per node
        even when some of the operations were added using markers.
        """
        self.main.start()
        self.main.action_q.download('foo', 'bar', 0, 0)
        self.main.action_q.uuid_map.set('foo', 'feh')
        self.main.action_q.uuid_map.set('bar', 'bah')
        self.main.action_q.upload('feh', 'bah', 0, 0, 0, 0, 0)
        self.assertEqual(len(self.main.action_q.content_queue.waiting), 1)

    def test_aq_resolve_uuid_maybe(self):
        """
        Check action_q.resolve_uuid_maybe does what it's supposed to
        """
        self.main.start()
        self.assertEqual(self.main.action_q.resolve_uuid_maybe('foo'), 'foo')
        self.main.action_q.uuid_map.set('foo', 'feh')
        self.assertEqual(self.main.action_q.resolve_uuid_maybe('foo'), 'feh')


class TestNoisyRQ(unittest.TestCase):
    """
    Tests for NoisyRequestQueue's change-notification callback.
    """

    def test_noisy_rq_blurts_about_head(self):
        """
        Setting the head makes the queue report (head, waiting).
        """
        def remember(head, waiting):
            """Store the snapshot the queue reported."""
            self.result = (head, tuple(waiting))

        rq = NoisyRequestQueue('name', None, remember)
        rq._head = 'blah'
        self.assertEqual(self.result, ('blah', ()))

    def test_noisy_rq_blurts_about_waiting(self):
        """
        Each change to the waiting queue (and the head) triggers a report.
        """
        class BlackHole(object):
            """The universal tool: any attribute or call yields another."""

            def __getattr__(self, name):
                return BlackHole()

            def __call__(self, *args, **kwargs):
                return BlackHole()

        seen = []

        def cb(head, waiting):
            """Record each (head, waiting) snapshot the queue reports."""
            seen.append((head, tuple(waiting)))

        first = FakeCommand()
        second = FakeCommand()
        rq = NoisyRequestQueue('name', BlackHole(), cb)
        rq.queue(second)
        rq.queue_top(first)
        rq.run()
        self.assertEqual(seen, [(None, ()),               # __init__
                                (None, (second,)),        # queue
                                (None, (first, second)),  # queue_top
                                (first, (second,)),       # run
                                (None, (second,)),        # done
                                ])


class FactoryBaseTestCase(BaseTwistedTestCase):
    """Base test case wiring a bare ActionQueue to a local test server."""

    timeout = 5

    def _start_sample_webserver(self):
        """Listen on port 55555 with an empty web site.

        The most recent server-side transport created by the port is kept
        in self.server_transport so tests can drop the connection from the
        server end. Listening stops during test cleanup.
        """
        webport = reactor.listenTCP(55555, server.Site(None))
        make_transport = webport.transport

        def save_an_instance(skt, protocol, addr, sself, s, sreactor):
            """Build the real transport and remember it."""
            transport = make_transport(skt, protocol, addr, sself,
                                       s, sreactor)
            self.server_transport = transport
            return transport

        webport.transport = save_an_instance
        self.addCleanup(webport.stopListening)
        return webport

    def _connect_factory(self):
        """Start the test server and connect the action queue to it.

        The queue is switched to TestingProtocol. The returned deferred is
        attached to the protocol it builds as ``testing_deferred``, and that
        protocol fires it from connectionMade.
        """
        self.server = self._start_sample_webserver()

        aq = self.action_queue
        aq.protocol = TestingProtocol
        real_build = aq.buildProtocol
        connected = defer.Deferred()

        def faked_buildProtocol(*args, **kwargs):
            """Build the protocol as usual and attach our deferred to it."""
            proto = real_build(*args, **kwargs)
            proto.testing_deferred = connected
            return proto

        aq.buildProtocol = faked_buildProtocol
        aq.connect()
        return connected

    def _disconnect_factory(self):
        """Disconnect the action queue; return a deferred for completion.

        With a client protocol present, the deferred fires after that
        client's connectionLost has run; otherwise it is already fired.
        """
        aq = self.action_queue
        if aq.client is None:
            done = defer.succeed(True)
        else:
            real_lost = aq.client.connectionLost
            done = defer.Deferred()

            def faked_connectionLost(reason):
                """Run the real connectionLost, then signal completion."""
                real_lost(reason)
                done.callback(True)

            aq.client.connectionLost = faked_connectionLost

        # NOTE(review): disconnect() is only requested while a connection
        # is in progress; if a client exists but none is, nothing here
        # triggers connectionLost and `done` waits for something else.
        if aq.connect_in_progress:
            aq.disconnect()

        return done

    def setUp(self):
        """Create a fresh ActionQueue pointed at the local test server."""
        BaseTwistedTestCase.setUp(self)
        self.action_queue = ActionQueue(event_queue=FakedEventQueue(),
                                        main=None,
                                        host='127.0.0.1', port=55555,
                                        dns_srv=False,
                                        connection_timeout=3)

    def tearDown(self):
        """Shut the event queue down and drop the action queue."""
        self.action_queue.event_queue.shutdown()
        self.action_queue = None
        BaseTwistedTestCase.tearDown(self)


class ConnectionTestCase(FactoryBaseTestCase):
    """Test TCP/SSL connection mechanism for ActionQueue."""

    def assert_connection_state_reset(self):
        """Test connection state is properly reset."""
        self.assertTrue(self.action_queue.client is None)
        self.assertTrue(self.action_queue.connector is None)
        self.assertEquals(False, self.action_queue.connect_in_progress)

    def _hook_and_check(self, method_name, expected_event, prepare=None):
        """Wrap action_queue.<method_name> to verify its effect once.

        When Twisted calls the wrapped method, the wrapper first restores
        the original. It then calls ``prepare`` (if given) and the original
        with the same arguments. Finally it checks that the connection
        state was reset and that exactly ``expected_event`` was pushed.

        Returns a deferred that fires with True when the checks pass. When
        they fail, it errbacks with the failure, so the test fails at once
        instead of running into the timeout.
        """
        orig = getattr(self.action_queue, method_name)
        d = defer.Deferred()

        def hook(*args, **kwargs):
            """Run the original method, check the outcome, fire d."""
            setattr(self.action_queue, method_name, orig)
            try:
                if prepare is not None:
                    prepare()
                orig(*args, **kwargs)
                self.assert_connection_state_reset()
                self.assertEquals([(expected_event, (), {})],
                                  self.action_queue.event_queue.events)
            except Exception:
                d.errback(Failure())
            else:
                d.callback(True)

        setattr(self.action_queue, method_name, hook)
        return d

    def test_init(self):
        """Test connection init state."""
        self.assert_connection_state_reset()

    @defer.inlineCallbacks
    def test_connect_if_already_connected(self):
        """Test that double connections are avoided."""
        yield self._connect_factory()

        # preconditions as unittest asserts: bare asserts vanish under -O
        self.assertTrue(self.action_queue.connector is not None)
        self.assertEquals(True, self.action_queue.connect_in_progress)
        # double connect, it returns None instead of a Deferred
        result = self.action_queue.connect()
        self.assertTrue(result is None, 'not connecting again')

        yield self._disconnect_factory()

    @defer.inlineCallbacks
    def test_disconnect_if_connected(self):
        """self.action_queue.connector.disconnect was called."""
        yield self._connect_factory()

        self.action_queue.event_queue.events = []  # cleanup events
        self.assertEquals('connected', self.action_queue.connector.state)
        self.action_queue.disconnect()

        self.assert_connection_state_reset()
        self.assertEquals([], self.action_queue.event_queue.events)

        yield self._disconnect_factory()

    @defer.inlineCallbacks
    def test_clientConnectionFailed(self):
        """Test clientConnectionFailed.

        The connection will not be completed since the server will be down.
        So, self.action_queue.connector will never leave the 'connecting' state.
        When interrupting the connection attempt, twisted automatically calls
        self.action_queue.clientConnectionFailed.

        """
        self.action_queue.event_queue.events = []

        def silence_errback():
            """Replace action_queue.deferred.errback with a no-op."""
            self.action_queue.deferred.errback = lambda _: None

        d = self._hook_and_check('clientConnectionFailed',
                                 'SYS_CONNECTION_FAILED',
                                 prepare=silence_errback)
        # factory will never finish the connection, server was never started
        self.action_queue.connect()
        # stopConnecting() will be called since the connection is in progress
        self.assertEquals('connecting', self.action_queue.connector.state)
        self.action_queue.connector.disconnect()

        yield d

    @defer.inlineCallbacks
    def test_clientConnectionLost(self):
        """Test clientConnectionLost

        The connection will be completed successfully.
        So, self.action_queue.connector will be in the 'connected' state.
        When disconnecting the connector, twisted automatically calls
        self.action_queue.clientConnectionLost.

        """
        yield self._connect_factory()

        self.action_queue.event_queue.events = []
        d = self._hook_and_check('clientConnectionLost',
                                 'SYS_CONNECTION_LOST')
        # loseConnection() will be called since the connection was completed
        self.assertEquals('connected', self.action_queue.connector.state)
        self.action_queue.connector.disconnect()
        yield d

        yield self._disconnect_factory()

    @defer.inlineCallbacks
    def test_server_disconnect(self):
        """Test factory's connection when the server goes down."""

        yield self._connect_factory()

        self.action_queue.event_queue.events = []
        d = self._hook_and_check('clientConnectionLost',
                                 'SYS_CONNECTION_LOST')
        # simulate a server failure! loseConnection() returns None, so
        # there is nothing to wait on here; d fires when the client notices
        self.server_transport.loseConnection()
        yield d
        yield self._disconnect_factory()

    def test_buildProtocol(self):
        """Test buildProtocol."""
        protocol = self.action_queue.buildProtocol(addr=None)
        self.assertTrue(protocol is self.action_queue.client)
        self.assertTrue(self.action_queue is self.action_queue.client.factory)

        # callbacks are connected
        # pylint: disable-msg=W0212
        self.assertEquals(self.action_queue.client._node_state_callback,
                          self.action_queue._node_state_callback)
        self.assertEquals(self.action_queue.client._share_change_callback,
                          self.action_queue._share_change_callback)
        self.assertEquals(self.action_queue.client._share_answer_callback,
                          self.action_queue._share_answer_callback)
        self.assertEquals(self.action_queue.client._free_space_callback,
                          self.action_queue._free_space_callback)
        self.assertEquals(self.action_queue.client._account_info_callback,
                          self.action_queue._account_info_callback)
        self.assertEquals(self.action_queue.client._volume_created_callback,
                          self.action_queue._volume_created_callback)
        self.assertEquals(self.action_queue.client._volume_deleted_callback,
                          self.action_queue._volume_deleted_callback)

    @defer.inlineCallbacks
    def test_connector_gets_assigned_on_connect(self):
        """Test factory's connector gets assigned on connect."""
        yield self._connect_factory()

        self.assertTrue(self.action_queue.connector is not None)

        yield self._disconnect_factory()

    @defer.inlineCallbacks
    def test_cleanup_doesnt_disconnect(self):
        """cleanup() doesn't disconnect the factory."""
        yield self._connect_factory()

        self.action_queue.cleanup()
        self.assertTrue(self.action_queue.connector is not None)
        self.assertEquals(self.action_queue.connector.state, 'connected')

        yield self._disconnect_factory()


class NetworkmanagerTestCase(BasicTestCase, FactoryBaseTestCase):
    """Connection handling driven by network and user events through Main.

    Combines the full Main setup of BasicTestCase with the local test
    server helper of FactoryBaseTestCase. The handshake steps are faked so
    that each one immediately pushes its success event.
    """

    timeout = 15

    def fake_answer(self, answer):
        """Return a callable that ignores its arguments and pushes answer."""
        return (lambda *_: self.action_queue.event_queue.push(answer))

    def setUp(self):
        """Init."""
        BasicTestCase.setUp(self)

        self.action_queue.local_rescan = \
            self.fake_answer('SYS_LOCAL_RESCAN_DONE')
        self.action_queue.check_version = \
            self.fake_answer('SYS_PROTOCOL_VERSION_OK')
        self.action_queue.set_capabilities = \
            self.fake_answer('SYS_SET_CAPABILITIES_OK')
        self.action_queue.authenticate = \
            self.fake_answer('SYS_AUTH_OK')
        self.action_queue.server_rescan = \
            self.fake_answer('SYS_SERVER_RESCAN_DONE')

        self.main.start()

    def tearDown(self):
        """Clean up."""
        BasicTestCase.tearDown(self)

    @defer.inlineCallbacks
    def test_wrong_disconnect(self):
        """Test factory's connection when SYS_NET_DISCONNECTED."""

        d1 = self.main.wait_for('SYS_CONNECTION_MADE')
        d2 = self.main.wait_for('SYS_CONNECTION_LOST')

        self.server = self._start_sample_webserver()
        self.action_queue.event_queue.push('SYS_USER_CONNECT',
                                           access_token='ble')
        yield d1

        self.action_queue.event_queue.push('SYS_NET_DISCONNECTED')
        yield d2

    @defer.inlineCallbacks
    def test_disconnect_twice(self):
        """Test connection when SYS_NET_DISCONNECTED is received twice."""

        d1 = self.main.wait_for('SYS_CONNECTION_MADE')
        d2 = self.main.wait_for('SYS_CONNECTION_LOST')

        self.server = self._start_sample_webserver()

        self.action_queue.event_queue.push('SYS_USER_CONNECT',
                                           access_token='ble')
        yield d1

        self.action_queue.event_queue.push('SYS_NET_DISCONNECTED')
        yield d2

        self.action_queue.event_queue.events = []
        self.action_queue.event_queue.push('SYS_NET_DISCONNECTED')
        self.assertEquals([('SYS_NET_DISCONNECTED', (), {})],
                          self.action_queue.event_queue.events,
                       'No new events after a missplaced SYS_NET_DISCONNECTED')

    @defer.inlineCallbacks
    def test_net_connected_if_already_connected(self):
        """Test connection when SYS_NET_CONNECTED is received twice."""

        d1 = self.main.wait_for('SYS_CONNECTION_MADE')

        self.server = self._start_sample_webserver()

        self.action_queue.event_queue.push('SYS_USER_CONNECT',
                                           access_token='ble')
        yield d1

        self.action_queue.event_queue.events = []
        self.action_queue.event_queue.push('SYS_NET_CONNECTED')
        self.assertEquals([('SYS_NET_CONNECTED', (), {})],
                          self.action_queue.event_queue.events,
                          'No new events after a missplaced SYS_NET_CONNECTED')

    @defer.inlineCallbacks
    def test_messy_mix(self):
        """Test connection when a messy mix of events is received."""
        # states.MAX_WAITING is module-global: restore it even when the
        # test fails, so the value 1 does not leak into later tests
        orig_waiting = states.MAX_WAITING
        states.MAX_WAITING = 1
        try:
            self.action_queue.event_queue.events = []
            self.server = self._start_sample_webserver()

            conn_made = self.main.wait_for('SYS_CONNECTION_MADE')
            self.action_queue.event_queue.push('SYS_USER_CONNECT',
                                               access_token='ble')
            yield conn_made

            events = ['SYS_NET_CONNECTED', 'SYS_NET_DISCONNECTED',
                      'SYS_NET_CONNECTED', 'SYS_NET_CONNECTED',
                      'SYS_NET_DISCONNECTED', 'SYS_NET_DISCONNECTED',
                      'SYS_NET_CONNECTED']

            for i in events:
                self.action_queue.event_queue.push(i)

            yield self.main.wait_for_nirvana()

            expected = ['SYS_NET_CONNECTED',  # from the DBus fake NetworkManager
                        'SYS_USER_CONNECT', 'SYS_CONNECTION_MADE',
                        'SYS_NET_CONNECTED', 'SYS_NET_DISCONNECTED',
                        'SYS_CONNECTION_LOST', 'SYS_CONNECTION_RETRY',
                        'SYS_NET_CONNECTED', 'SYS_NET_CONNECTED',
                        'SYS_CONNECTION_MADE', 'SYS_NET_DISCONNECTED',
                        'SYS_NET_DISCONNECTED']

            avoid = ('SYS_STATE_CHANGED', 'SYS_LOCAL_RESCAN_DONE',
                     'SYS_PROTOCOL_VERSION_OK', 'SYS_SET_CAPABILITIES_OK',
                     'SYS_AUTH_OK', 'SYS_SERVER_RESCAN_DONE')
            actual = [event for (event, args, kwargs) in
                      self.action_queue.event_queue.events
                      if event not in avoid]
            self.assertEquals(sorted(expected), sorted(actual))
        finally:
            states.MAX_WAITING = orig_waiting


class ConnectedBaseTestCase(FactoryBaseTestCase):
    """Base test case that starts every test with a connected factory."""

    @defer.inlineCallbacks
    def setUp(self):
        """Init: run the base setUp, then connect to the test server."""
        FactoryBaseTestCase.setUp(self)
        yield self._connect_factory()
        # a unittest assertion, unlike a bare assert, survives python -O
        self.assertEquals('connected', self.action_queue.connector.state)

    @defer.inlineCallbacks
    def tearDown(self):
        """Clean up: disconnect, then run the base tearDown."""
        yield self._disconnect_factory()
        FactoryBaseTestCase.tearDown(self)


class VolumeManagementTestCase(ConnectedBaseTestCase):
    """Test volume management in ActionQueue."""

    def test_volume_created_push_event(self):
        """The volume-created callback pushes SV_VOLUME_CREATED."""
        volume = FakedVolume()
        self.action_queue._volume_created_callback(volume)
        expected = [('SV_VOLUME_CREATED', (), {'volume': volume})]
        self.assertEquals(expected, self.action_queue.event_queue.events)

    def test_volume_deleted_push_event(self):
        """The volume-deleted callback pushes SV_VOLUME_DELETED."""
        self.action_queue._volume_deleted_callback(VOLUME)
        expected = [('SV_VOLUME_DELETED', (), {'volume_id': VOLUME})]
        self.assertEquals(expected, self.action_queue.event_queue.events)

    def test_valid_events(self):
        """Volume events exist in EVENTS with the expected arguments."""
        signatures = [
            ('SV_VOLUME_CREATED', ('volume',)),
            ('SV_VOLUME_DELETED', ('volume_id',)),
            ('AQ_CREATE_UDF_OK', ('volume_id', 'node_id', 'marker')),
            ('AQ_CREATE_UDF_ERROR', ('error', 'marker')),
            ('AQ_LIST_VOLUMES', ('volumes',)),
            ('AQ_LIST_VOLUMES_ERROR', ('error',)),
            ('AQ_DELETE_VOLUME_OK', ('volume_id',)),
            ('AQ_DELETE_VOLUME_ERROR', ('volume_id', 'error')),
        ]
        # first every name must be registered, then every signature match
        for event, _ in signatures:
            self.assertTrue(event in EVENTS)
        for event, args in signatures:
            self.assertEquals(args, EVENTS[event])

    def test_create_udf(self):
        """create_udf returns None, which is what start returns."""
        self.assertTrue(
            self.action_queue.create_udf(PATH, NAME, marker=None) is None)

    def test_list_volumes(self):
        """list_volumes returns None, which is what start returns."""
        self.assertTrue(self.action_queue.list_volumes() is None)

    def test_delete_volume(self):
        """delete_volume returns None, which is what start returns."""
        self.assertTrue(self.action_queue.delete_volume(VOLUME) is None)


class CreateUDFTestCase(ConnectedBaseTestCase):
    """Tests for the CreateUDF ActionQueueCommand."""

    def setUp(self):
        """Connect, then build a CreateUDF command on a fresh queue."""
        connected = super(CreateUDFTestCase, self).setUp()
        self.marker = VOLUME
        rq = RequestQueue(name='foo', action_queue=self.action_queue)
        self.command = CreateUDF(rq, PATH, NAME, marker=self.marker)
        return connected

    def test_is_action_queue_command(self):
        """CreateUDF is an ActionQueueCommand."""
        is_command = isinstance(self.command, ActionQueueCommand)
        self.assertTrue(is_command)

    def test_init(self):
        """The command keeps path, name and marker."""
        cmd = self.command
        self.assertEquals(PATH, cmd.path)
        self.assertEquals(NAME, cmd.name)
        self.assertEquals(self.marker, cmd.marker)

    def test_run_returns_a_deferred(self):
        """_run hands back a Deferred."""
        self.assertTrue(isinstance(self.command._run(), defer.Deferred),
                        'deferred returned')

    def test_run_calls_protocol(self):
        """_run passes path and name to the client's create_udf."""
        aq = self.command.action_queue
        saved = aq.client.create_udf
        self.called = False

        def fake_create_udf(path, name):
            """Record the call and check its arguments."""
            self.called = True
            self.assertEquals(PATH, path)
            self.assertEquals(NAME, name)

        aq.client.create_udf = fake_create_udf
        self.command._run()
        self.assertTrue(self.called, 'command was called')
        aq.client.create_udf = saved

    def test_handle_success_push_event(self):
        """On success AQ_CREATE_UDF_OK carries volume, node and marker."""
        request = client.CreateUDF(self.action_queue.client, PATH, NAME)
        request.volume_id = VOLUME
        request.node_id = NODE
        result = self.command.handle_success(success=request)
        payload = dict(volume_id=VOLUME, node_id=NODE, marker=self.marker)
        self.assertEquals([('AQ_CREATE_UDF_OK', (), payload)],
                          self.command.action_queue.event_queue.events)
        self.assertEquals(request, result)

    def test_handle_failure_push_event(self):
        """On failure AQ_CREATE_UDF_ERROR carries the message and marker."""
        msg = 'Something went wrong'
        result = self.command.handle_failure(
            failure=Failure(DefaultException(msg)))
        payload = dict(error=msg, marker=self.marker)
        self.assertEquals([('AQ_CREATE_UDF_ERROR', (), payload)],
                          self.command.action_queue.event_queue.events)
        self.assertTrue(result is None)


class ActionQueueCommandErrors(ConnectedBaseTestCase):
    """How ActionQueueCommand logs failures and when it retries."""

    def setUp(self):
        """Connect, then build a command with a recording log and retry.

        self.deferred fires when the command's retry() is invoked.
        """
        res = super(ActionQueueCommandErrors, self).setUp()
        self.deferred = defer.Deferred()
        test = self

        class MyLogger(object):
            """Fake log that only remembers the last error/warn call."""

            def __init__(self):
                """Start with nothing logged."""
                self.logged = None

            def error(self, *a):
                """Remember that error() was the last level used."""
                self.logged = "error"

            def warn(self, *a):
                """Remember that warn() was the last level used."""
                self.logged = "warn"

            def debug(self, *a):
                """Ignore debug output."""

        class MyCommand(ActionQueueCommand):
            """Command with a recording log whose retry() signals the test."""
            # class-closure, cannot use self, pylint: disable-msg=E0213
            def __init__(this, request_queue):
                super(MyCommand, this).__init__(request_queue)
                this.log = MyLogger()

            def retry(this):
                """Signal the retry."""
                test.deferred.callback(True)

        self.rq = RequestQueue(name='foo', action_queue=self.action_queue)
        self.command = MyCommand(self.rq)
        return res

    def test_suppressed_yes_knownerrors(self):
        """Every mapped error except InternalError is logged as a warning."""
        for errnum, exception_class in errors._error_mapping.items():
            if exception_class == errors.InternalError:
                continue
            msg = protocol_pb2.Message()
            msg.type = protocol_pb2.Message.ERROR
            msg.error.type = errnum
            failure = Failure(exception_class("request", msg))

            self.command.log.logged = None
            self.command.end_errback(failure=failure)
            self.assertEqual(self.command.log.logged, "warn",
                             "Bad log in exception %s" % (exception_class,))

    def test_suppressed_no_internalerror(self):
        """InternalError is logged as an error."""
        msg = protocol_pb2.Message()
        msg.type = protocol_pb2.Message.ERROR
        msg.error.type = protocol_pb2.Error.INTERNAL_ERROR
        failure = Failure(errors.InternalError("request", msg))

        self.command.end_errback(failure=failure)
        self.assertEqual(self.command.log.logged, "error")

    def test_suppressed_yes_cancelled(self):
        """A cancelled request is logged as a warning."""
        failure = Failure(errors.RequestCancelledError("CANCELLED"))
        self.command.end_errback(failure=failure)
        self.assertEqual(self.command.log.logged, "warn")

    def test_suppressed_yes_and_retry_when_connectiondone(self):
        """ConnectionDone is logged as a warning and triggers a retry."""
        self.command.running = True
        self.command.end_errback(
            failure=Failure(twisted_error.ConnectionDone()))
        self.assertEqual(self.command.log.logged, "warn")
        return self.deferred

    def test_retry_connectionlost(self):
        """ConnectionLost triggers a retry."""
        self.command.running = True
        self.command.end_errback(
            failure=Failure(twisted_error.ConnectionLost()))
        return self.deferred

    def test_retry_tryagain(self):
        """A TRY_AGAIN protocol error triggers a retry."""
        self.command.running = True
        msg = protocol_pb2.Message()
        msg.type = protocol_pb2.Message.ERROR
        msg.error.type = protocol_pb2.Error.TRY_AGAIN
        failure = Failure(errors.TryAgainError("request", msg))

        self.command.end_errback(failure=failure)
        return self.deferred


class ListVolumesTestCase(ConnectedBaseTestCase):
    """Test for ListVolumes ActionQueueCommand."""

    def setUp(self):
        """Init."""
        res = super(ListVolumesTestCase, self).setUp()

        request_queue = RequestQueue(name='foo',
                                     action_queue=self.action_queue)
        self.command = ListVolumes(request_queue)

        return res

    def test_is_action_queue_command(self):
        """Test proper inheritance."""
        self.assertTrue(isinstance(self.command, ActionQueueCommand))

    def test_run_returns_a_deferred(self):
        """Test a deferred is returned."""
        res = self.command._run()
        self.assertTrue(isinstance(res, defer.Deferred), 'deferred returned')

    def test_run_calls_protocol(self):
        """Test protocol's list_volumes is called."""
        aq_client = self.command.action_queue.client
        original = aq_client.list_volumes
        self.called = False

        def check():
            """Take control over client's feature."""
            self.called = True

        aq_client.list_volumes = check
        try:
            self.command._run()
        finally:
            # restore even if _run raises, so the fake can't leak
            aq_client.list_volumes = original

        self.assertTrue(self.called, 'command was called')

    def test_handle_success_push_event(self):
        """Test AQ_LIST_VOLUMES is pushed on success."""
        request = client.ListVolumes(self.action_queue.client)
        request.volumes = [FakedVolume(), FakedVolume()]
        res = self.command.handle_success(success=request)
        events = [('AQ_LIST_VOLUMES', (), {'volumes': request.volumes})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertEquals(request, res)

    def test_handle_failure_push_event(self):
        """Test AQ_LIST_VOLUMES_ERROR is pushed on failure."""
        msg = 'Something went wrong'
        failure = Failure(DefaultException(msg))
        res = self.command.handle_failure(failure=failure)
        events = [('AQ_LIST_VOLUMES_ERROR', (), {'error': msg})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertTrue(res is None)


class DeleteVolumeTestCase(ConnectedBaseTestCase):
    """Test for DeleteVolume ActionQueueCommand."""

    def setUp(self):
        """Init."""
        res = super(DeleteVolumeTestCase, self).setUp()

        request_queue = RequestQueue(name='foo',
                                     action_queue=self.action_queue)
        self.command = DeleteVolume(request_queue, VOLUME)

        return res

    def test_is_action_queue_command(self):
        """Test proper inheritance."""
        self.assertTrue(isinstance(self.command, ActionQueueCommand))

    def test_init(self):
        """Test creation."""
        self.assertEquals(VOLUME, self.command.volume_id)

    def test_run_returns_a_deferred(self):
        """Test a deferred is returned."""
        res = self.command._run()
        self.assertTrue(isinstance(res, defer.Deferred), 'deferred returned')

    def test_run_calls_protocol(self):
        """Test protocol's delete_volume is called with the volume id."""
        aq_client = self.command.action_queue.client
        original = aq_client.delete_volume
        self.called = False

        def check(volume_id):
            """Take control over client's feature."""
            self.called = True
            self.assertEquals(VOLUME, volume_id)

        aq_client.delete_volume = check
        try:
            self.command._run()
        finally:
            # restore even if _run (or check's own assertion) raises
            aq_client.delete_volume = original

        self.assertTrue(self.called, 'command was called')

    def test_handle_success_push_event(self):
        """Test AQ_DELETE_VOLUME_OK is pushed on success."""
        request = client.DeleteVolume(self.action_queue.client,
                                      volume_id=VOLUME)
        res = self.command.handle_success(success=request)
        events = [('AQ_DELETE_VOLUME_OK', (), {'volume_id': VOLUME})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertEquals(request, res)

    def test_handle_failure_push_event(self):
        """Test AQ_DELETE_VOLUME_ERROR is pushed on failure."""
        msg = 'Something went wrong'
        failure = Failure(DefaultException(msg))
        res = self.command.handle_failure(failure=failure)
        events = [('AQ_DELETE_VOLUME_ERROR', (),
                   {'volume_id': VOLUME, 'error': msg})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertTrue(res is None)


class FilterEventsTestCase(BasicTestCase):
    """Tests for event filtering when a volume is not of our interest."""

    def setUp(self):
        """Init: point HOME at the test home, remembering the old value."""
        BasicTestCase.setUp(self)
        self.vm = self.main.vm
        self.old_home = os.environ.get('HOME', None)
        os.environ['HOME'] = self.home

    def tearDown(self):
        """Clean up: restore (or remove) HOME, then tear down the base."""
        if self.old_home is None:
            # HOME was unset before setUp; don't fail if it's already gone
            os.environ.pop('HOME', None)
        else:
            os.environ['HOME'] = self.old_home

        BasicTestCase.tearDown(self)

    @defer.inlineCallbacks
    def _add_udf_with_subscription(self, udf_id, subscribed):
        """Add a UDF with id udf_id to the VM, (un)subscribe it, and reset
        the recorded events.

        The UDF's subscribed state is checked against `subscribed` (by
        truth value) after the (un)subscription.
        """
        udf_volume = volumes.UDFVolume(udf_id, 'udf_node',
                                       u'~/ñoño'.encode("utf8"))
        path = self.vm._build_udf_path(udf_volume.suggested_path)
        udf = UDF.from_udf_volume(udf_volume, path)
        yield self.vm.add_udf(udf)
        if subscribed:
            yield self.vm.subscribe_udf(udf_id)
            self.assertTrue(self.vm.udfs[udf_id].subscribed)
        else:
            yield self.vm.unsubscribe_udf(udf_id)
            self.assertFalse(self.vm.udfs[udf_id].subscribed)
        self.action_queue.event_queue.events = [] # reset events

    @defer.inlineCallbacks
    def test_SV_HASH_NEW_is_pushed_for_subscrined_volume(self):
        """SV_HASH_NEW is pushed when the volume is subscribed."""
        udf_id = 'udf_id'
        yield self._add_udf_with_subscription(udf_id, subscribed=True)

        kwargs = dict(share_id=udf_id, node_id=NODE, hash=None)
        self.action_queue._node_state_callback(**kwargs)
        self.assertEquals([('SV_HASH_NEW', (), kwargs)],
                          self.action_queue.event_queue.events)

    @defer.inlineCallbacks
    def test_SV_HASH_NEW_is_filtered_for_unsubscrined_volume(self):
        """SV_HASH_NEW is filtered when the volume is unsubscribed."""
        # add an UDF to the VM with subscribed set to False
        udf_id = 'udf_id'
        yield self._add_udf_with_subscription(udf_id, subscribed=False)

        self.action_queue._node_state_callback(share_id=udf_id,
                                               node_id=None, hash=None)
        self.assertEquals([], self.action_queue.event_queue.events)

    def test_SV_HASH_NEW_doesnt_fail_for_non_udf(self):
        """SV_HASH_NEW keeps working like before for non-udfs."""
        other_id = 'not in udfs'
        self.assertFalse(other_id in self.vm.udfs)
        self.action_queue.event_queue.events = [] # reset events

        kwargs = dict(share_id=other_id, node_id=NODE, hash=None)
        self.action_queue._node_state_callback(**kwargs)
        self.assertEquals([('SV_HASH_NEW', (), kwargs)],
                          self.action_queue.event_queue.events)


class ChangePublicAccessTests(ConnectedBaseTestCase):
    """Tests for the ChangePublicAccess ActionQueueCommand."""

    def setUp(self):
        """Init."""
        res = super(ChangePublicAccessTests, self).setUp()
        request_queue = RequestQueue(name='foo',
                                     action_queue=self.action_queue)
        self.command = ChangePublicAccess(request_queue, VOLUME, NODE, True)
        return res

    def test_change_public_access(self):
        """Test the change_public_access method."""
        res = self.action_queue.change_public_access(VOLUME, NODE, True)
        self.assertTrue(res is None) # this is what start returns

    def test_is_action_queue_command(self):
        """Test proper inheritance."""
        self.assertTrue(isinstance(self.command, ActionQueueCommand))

    def test_init(self):
        """Test creation."""
        self.assertEquals(VOLUME, self.command.share_id)
        self.assertEquals(NODE, self.command.node_id)
        self.assertEquals(True, self.command.is_public)

    def test_run_defers_work_to_thread(self):
        """Test that work is deferred to a thread."""
        original = threads.deferToThread
        self.called = False

        def check(function):
            """Fake deferToThread: check the function, return a Deferred."""
            self.called = True
            self.assertEquals(
                self.command._change_public_access_http, function)
            return defer.Deferred()

        threads.deferToThread = check
        try:
            res = self.command._run()
        finally:
            threads.deferToThread = original

        self.assertTrue(isinstance(res, defer.Deferred), 'deferred returned')
        self.assertTrue(self.called, "deferToThread was called")

    def test_change_public_access_http(self):
        """Test the blocking portion of the command."""
        from ubuntuone.syncdaemon import action_queue
        self.called = False

        def check(request):
            """Fake urlopen: check the request, return a canned reply."""
            self.called = True
            url = 'https://one.ubuntu.com/files/api/set_public/%s:%s' % (
                base64.urlsafe_b64encode(VOLUME.bytes).strip("="),
                base64.urlsafe_b64encode(NODE.bytes).strip("="))
            self.assertEqual(url, request.get_full_url())
            self.assertEqual("is_public=True", request.get_data())
            return StringIO(
                '{"is_public": true, "public_url": "http://example.com"}')

        original = action_queue.urlopen
        action_queue.urlopen = check
        try:
            res = self.command._change_public_access_http()
        finally:
            # put back whatever was there, not an assumed urllib2.urlopen
            action_queue.urlopen = original

        self.assertTrue(self.called, "urlopen was called")
        self.assertEqual(
            {'is_public': True, 'public_url': 'http://example.com'}, res)

    def test_handle_success_push_event(self):
        """Test AQ_CHANGE_PUBLIC_ACCESS_OK is pushed on success."""
        response = {'is_public': True, 'public_url': 'http://example.com'}
        res = self.command.handle_success(success=response)
        events = [('AQ_CHANGE_PUBLIC_ACCESS_OK', (),
                   {'share_id': VOLUME, 'node_id': NODE, 'is_public': True,
                    'public_url': 'http://example.com'})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertEquals(response, res)

    def test_handle_failure_push_event(self):
        """Test AQ_CHANGE_PUBLIC_ACCESS_ERROR is pushed on failure."""
        msg = 'Something went wrong'
        failure = Failure(urllib2.HTTPError(
                "http://example.com", 500, "Error", [], StringIO(msg)))
        res = self.command.handle_failure(failure=failure)
        events = [('AQ_CHANGE_PUBLIC_ACCESS_ERROR', (),
                   {'share_id': VOLUME, 'node_id': NODE, 'error': msg})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertTrue(res is None)


class GetPublicFilesTestCase(ConnectedBaseTestCase):
    """Tests for GetPublicFiles ActionQueueCommand."""

    def setUp(self):
        """Init."""
        res = super(GetPublicFilesTestCase, self).setUp()
        request_queue = RequestQueue(name='foo',
                                     action_queue=self.action_queue)
        self.command = GetPublicFiles(request_queue)
        return res

    def test_init(self):
        """Test __init__ method: default and custom base urls."""
        default_url = 'https://one.ubuntu.com/files/api/public_files'
        request_queue = RequestQueue(name='foo',
                                     action_queue=self.action_queue)
        command = GetPublicFiles(request_queue)
        self.assertEquals(command._url, default_url)
        custom_url = 'http://example.com:1234/files/api/public_files'
        command_2 = GetPublicFiles(request_queue,
                                   base_url='http://example.com:1234')
        self.assertEquals(command_2._url, custom_url)

    def test_change_public_access(self):
        """Test the get_public_files method."""
        res = self.action_queue.get_public_files()
        self.assertTrue(res is None) # this is what start returns

    def test_is_action_queue_command(self):
        """Test proper inheritance."""
        self.assertTrue(isinstance(self.command, ActionQueueCommand))

    def test_run_defers_work_to_thread(self):
        """Test that work is deferred to a thread."""
        original = threads.deferToThread
        self.called = False

        def check(function):
            """Fake deferToThread: check the function, return a Deferred."""
            self.called = True
            self.assertEquals(
                self.command._get_public_files_http, function)
            return defer.Deferred()

        threads.deferToThread = check
        try:
            res = self.command._run()
        finally:
            threads.deferToThread = original

        self.assertTrue(isinstance(res, defer.Deferred), 'deferred returned')
        self.assertTrue(self.called, "deferToThread was called")

    def test_get_public_files_http(self):
        """Test the blocking portion of the command."""
        from ubuntuone.syncdaemon import action_queue
        self.called = False
        node_id = uuid.uuid4()
        nodekey = base64.urlsafe_b64encode(node_id.bytes).strip("=")
        node_id_2 = uuid.uuid4()
        nodekey_2 = base64.urlsafe_b64encode(node_id_2.bytes).strip("=")
        volume_id = uuid.uuid4()

        def check(request):
            """Fake urlopen: check the url, return two canned entries."""
            self.called = True
            url = 'https://one.ubuntu.com/files/api/public_files'
            self.assertEqual(url, request.get_full_url())
            return StringIO(
                '[{"nodekey": "%s", "volume_id": null,"public_url": '
                '"http://example.com"}, '
                '{"nodekey": "%s", "volume_id": "%s", "public_url": '
                '"http://example.com"}]' % (nodekey, nodekey_2, volume_id))

        original = action_queue.urlopen
        action_queue.urlopen = check
        try:
            res = self.command._get_public_files_http()
        finally:
            # put back whatever was there, not an assumed urllib2.urlopen
            action_queue.urlopen = original

        self.assertTrue(self.called, "urlopen was called")
        self.assertEqual([{'node_id': str(node_id), 'volume_id': '',
                           'public_url': 'http://example.com'},
                          {'node_id': str(node_id_2),
                           'volume_id': str(volume_id),
                           'public_url': 'http://example.com'}], res)

    def test_handle_success_push_event(self):
        """Test AQ_PUBLIC_FILES_LIST_OK is pushed on success."""
        response = [{'node_id': uuid.uuid4(), 'volume_id': None,
                     'public_url': 'http://example.com'}]
        res = self.command.handle_success(success=response)
        events = [('AQ_PUBLIC_FILES_LIST_OK', (), {'public_files': response})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertEquals(response, res)

    def test_handle_failure_push_event(self):
        """Test AQ_PUBLIC_FILES_LIST_ERROR is pushed on failure."""
        msg = 'Something went wrong'
        failure = Failure(urllib2.HTTPError(
                "http://example.com", 500, "Error", [], StringIO(msg)))
        res = self.command.handle_failure(failure=failure)
        events = [('AQ_PUBLIC_FILES_LIST_ERROR', (), {'error': msg})]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertTrue(res is None)


class ListDirTestCase(ConnectedBaseTestCase):
    """Test for ListDir ActionQueueCommand."""

    def setUp(self):
        """Init."""
        res = super(ListDirTestCase, self).setUp()

        request_queue = RequestQueue(name='FOO',
                                     action_queue=self.action_queue)
        self.command = ListDir(request_queue, share_id='a_share_id',
                               node_id='a_node_id', server_hash='a_server_hash',
                               fileobj_factory=lambda: None)
        self.command.start_unqueued() # create the logger

        return res

    def _check_handle_failure(self, failure, event_name, **event_kwargs):
        """Feed failure to handle_failure and check what it did.

        Exactly one event (event_name, (), event_kwargs) must have been
        pushed, and handle_failure must have returned None.
        """
        res = self.command.handle_failure(failure=failure)
        events = [(event_name, (), event_kwargs)]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertTrue(res is None)

    def test_failure_with_CANCELLED(self):
        """AQ_DOWNLOAD_CANCELLED is pushed."""
        err = errors.RequestCancelledError("CANCELLED")
        self._check_handle_failure(Failure(err), 'AQ_DOWNLOAD_CANCELLED',
                                   share_id='a_share_id', node_id='a_node_id',
                                   server_hash='a_server_hash')

    def test_failure_without_CANCELLED(self):
        """AQ_DOWNLOAD_ERROR with proper error is pushed."""
        msg = 'NOT_CANCELLED'
        self._check_handle_failure(Failure(DefaultException(msg)),
                                   'AQ_DOWNLOAD_ERROR',
                                   share_id='a_share_id', node_id='a_node_id',
                                   server_hash='a_server_hash', error=msg)

    def test_failure_with_DOES_NOT_EXIST(self):
        """AQ_DOWNLOAD_DOES_NOT_EXIST is pushed."""
        protocol_msg = protocol_pb2.Message()
        protocol_msg.type = protocol_pb2.Message.ERROR
        protocol_msg.error.type = protocol_pb2.Error.DOES_NOT_EXIST
        err = errors.DoesNotExistError("request", protocol_msg)
        self._check_handle_failure(Failure(err), 'AQ_DOWNLOAD_DOES_NOT_EXIST',
                                   share_id='a_share_id', node_id='a_node_id')

    def test_failure_without_DOES_NOT_EXIST(self):
        """AQ_DOWNLOAD_ERROR with proper error is pushed."""
        msg = 'DOES_EXIST'
        self._check_handle_failure(Failure(DefaultException(msg)),
                                   'AQ_DOWNLOAD_ERROR',
                                   share_id='a_share_id', node_id='a_node_id',
                                   server_hash='a_server_hash', error=msg)

    def test_AQ_DOWNLOAD_DOES_NOT_EXIST_is_valid_event(self):
        """AQ_DOWNLOAD_DOES_NOT_EXIST is a valid event."""
        event = 'AQ_DOWNLOAD_DOES_NOT_EXIST'
        self.assertTrue(event in EVENTS)
        self.assertEquals(('share_id', 'node_id'), EVENTS[event])


class UploadTestCase(ConnectedBaseTestCase):
    """Test for Upload ActionQueueCommand."""

    def setUp(self):
        """Init."""
        res = super(UploadTestCase, self).setUp()

        self.rq = request_queue = RequestQueue(name='FOO',
                                               action_queue=self.action_queue)
        self.command = Upload(request_queue, share_id='a_share_id',
                              node_id='a_node_id', previous_hash='prev_hash',
                              hash='yadda', crc32=0, size=0,
                              fileobj_factory=lambda: None,
                              tempfile_factory=lambda: None)
        self.command.start_unqueued() # create the logger

        return res

    def test_upload_in_progress(self):
        """Test Upload retries on UploadInProgress."""
        # monkeypatching is not allowed, let's do inheritance
        d = defer.Deferred()

        class MyUpload(Upload):
            """Just to redefine retry."""
            def retry(self):
                """Detect retry was called."""
                d.callback(True)

        # set up the command
        command = MyUpload(self.rq, 'share', 'bode', 'prvhash', 'currhash',
                           0, 0, lambda: None, lambda: None)
        command.start_unqueued() # create log in the instance
        command.running = True

        # send the failure
        protocol_msg = protocol_pb2.Message()
        protocol_msg.type = protocol_pb2.Message.ERROR
        protocol_msg.error.type = protocol_pb2.Error.UPLOAD_IN_PROGRESS
        err = errors.UploadInProgressError("request", protocol_msg)
        command.end_errback(failure=Failure(err))
        return d

    def test_handle_success_push_event(self):
        """Test AQ_UPLOAD_FINISHED is pushed on success."""
        self.command.handle_success(None)
        kwargs = dict(share_id='a_share_id', node_id='a_node_id', hash='yadda')
        events = [('AQ_UPLOAD_FINISHED', (), kwargs)]
        self.assertEquals(events, self.command.action_queue.event_queue.events)

    def test_handle_failure_push_event(self):
        """Test AQ_UPLOAD_ERROR is pushed on failure."""
        msg = 'Something went wrong'
        failure = Failure(DefaultException(msg))
        res = self.command.handle_failure(failure=failure)
        kwargs = dict(share_id='a_share_id', node_id='a_node_id',
                      hash='yadda', error=msg)
        events = [('AQ_UPLOAD_ERROR', (), kwargs)]
        self.assertEquals(events, self.command.action_queue.event_queue.events)
        self.assertTrue(res is None)

    def test_handle_failure_removes_temp_file(self):
        """Test temp file is removed on failure."""
        class TempFile(object):
            """Minimal stand-in: only a name attribute is set on it."""

        self.command.tempfile = TempFile()
        self.command.tempfile.name = os.path.join(self.tmpdir, 'remove-me.zip')
        open(self.command.tempfile.name, 'w').close()
        self.assertTrue(os.path.exists(self.command.tempfile.name))

        msg = 'Something went wrong'
        failure = Failure(DefaultException(msg))
        self.command.handle_failure(failure=failure)

        self.assertFalse(os.path.exists(self.command.tempfile.name))

    @defer.inlineCallbacks
    def test_compress_failed_pushes_upload_error(self):
        """Test AQ_UPLOAD_ERROR is pushed when zipping the file fails."""
        msg = 'Zip can not be accomplished.'
        error = DefaultException(msg)
        self.action_queue.zip_queue.zip = lambda upload: defer.fail(error)
        yield self.command.run()
        kwargs = dict(share_id='a_share_id', node_id='a_node_id',
                      hash='yadda', error=msg)
        events = [('AQ_UPLOAD_ERROR', (), kwargs)]
        self.assertEquals(events, self.command.action_queue.event_queue.events)


class CreateShareTestCase(ConnectedBaseTestCase):
    """Test for CreateShare ActionQueueCommand."""

    @defer.inlineCallbacks
    def setUp(self):
        """Init."""
        yield super(CreateShareTestCase, self).setUp()
        self.request_queue = RequestQueue(name='foo',
                                          action_queue=self.action_queue)
        self.orig_create_share_http = CreateShare._create_share_http

    @defer.inlineCallbacks
    def tearDown(self):
        """Restore CreateShare._create_share_http, then clean up."""
        # Restore the class attribute *before* the parent tearDown: if that
        # fails, the fake must not stay installed for later tests.
        CreateShare._create_share_http = self.orig_create_share_http
        yield super(CreateShareTestCase, self).tearDown()

    @defer.inlineCallbacks
    def _check_access_level(self, access_level, expect_read_only):
        """Run an http CreateShare with access_level and check the args.

        _create_share_http is replaced by a fake that records what it is
        called with; the read_only flag it receives must match
        expect_read_only by truth value.
        """
        d = defer.Deferred()

        def check_create_http(cmd, node_id, user, name, read_only, deferred):
            """Fire d with the args, and the command's deferred with None."""
            d.callback((node_id, user, name, read_only))
            deferred.callback(None)
        CreateShare._create_share_http = check_create_http

        command = CreateShare(self.request_queue, 'node_id',
                              'share_to@example.com', 'share_name',
                              access_level, 'marker')
        self.assertTrue(command.use_http, 'CreateShare should be in http mode')

        command._run()
        node_id, user, name, read_only = yield d
        self.assertEquals('node_id', node_id)
        self.assertEquals('share_to@example.com', user)
        self.assertEquals('share_name', name)
        if expect_read_only:
            self.assertTrue(read_only)
        else:
            self.assertFalse(read_only)

    def test_access_level_modify_http(self):
        """Test proper handling of the access level in the http case."""
        return self._check_access_level('Modify', expect_read_only=False)

    def test_access_level_view_http(self):
        """Test proper handling of the access level in the http case."""
        return self._check_access_level('View', expect_read_only=True)


class RequestQueueManager(FactoryBaseTestCase):
    """Check how the meta RequestQueue tracks its queued commands."""

    def setUp(self):
        """Use the AQ's meta queue and a single fake command."""
        FactoryBaseTestCase.setUp(self)

        self.queue = self.action_queue.meta_queue
        self.cmd = FakeCommand()

    def _events(self):
        """Return only the names of the events pushed so far."""
        pushed = self.action_queue.event_queue.events
        return [event[0] for event in pushed]

    def test_empty_gets_one(self):
        """Queueing into an empty queue pushes the waiting event."""
        self.queue.queue(self.cmd)
        self.assertEqual(['SYS_META_QUEUE_WAITING'], self._events())

    def test_with_one_gets_second(self):
        """A second queued command pushes no extra event."""
        for _ in range(2):
            self.queue.queue(self.cmd)
        # just the event from the first one
        self.assertEqual(['SYS_META_QUEUE_WAITING'], self._events())

    @defer.inlineCallbacks
    def test_empty_run(self):
        """Running an empty queue pushes nothing."""
        yield self.queue.run()
        self.assertEqual([], self._events())

    @defer.inlineCallbacks
    def test_with_one_run(self):
        """Running executes the only command and reports done."""
        self.queue.queue(self.cmd)
        yield self.queue.run()
        expected = ['SYS_META_QUEUE_WAITING', 'SYS_META_QUEUE_DONE']
        self.assertEqual(expected, self._events())

    @defer.inlineCallbacks
    def test_with_two_run(self):
        """Running twice executes both commands."""
        waiting = 'SYS_META_QUEUE_WAITING'
        done = 'SYS_META_QUEUE_DONE'

        # first queueing pushes the event...
        self.queue.queue(self.cmd)
        self.assertEqual([waiting], self._events())

        # ...the second one doesn't
        self.queue.queue(self.cmd)
        self.assertEqual([waiting], self._events())

        # the first run leaves one pending, so we wait again
        yield self.queue.run()
        self.assertEqual([waiting, waiting], self._events())

        # the second run drains the queue
        yield self.queue.run()
        self.assertEqual([waiting, waiting, done], self._events())

    def test_len_empty(self):
        """A fresh queue has length zero."""
        self.assertEqual(0, len(self.queue))

    def test_len_with_one(self):
        """One queued command gives length one."""
        self.queue.queue(self.cmd)
        self.assertEqual(1, len(self.queue))

    def test_len_with_two(self):
        """Two queued commands give length two."""
        for _ in range(2):
            self.queue.queue(self.cmd)
        self.assertEqual(2, len(self.queue))

    @defer.inlineCallbacks
    def test_len_run_decreases(self):
        """Length grows when queueing and shrinks when running."""
        for expected in (1, 2):
            self.queue.queue(self.cmd)
            self.assertEqual(expected, len(self.queue))
        for expected in (1, 0):
            yield self.queue.run()
            self.assertEqual(expected, len(self.queue))


class SimpleAQTestCase(BasicTestCase):
    """Simple tests for AQ API."""

    def test_aq_server_rescan(self):
        """AQ.server_rescan asks get_root for the mdid it was given."""
        self.main.start()
        root_requested = defer.Deferred()

        def get_root(mdid):
            """Fake get_root: report which mdid was asked for."""
            root_requested.callback(mdid)

        def check(mdid):
            """The mdid handed to server_rescan reached get_root."""
            self.assertEquals('foo', mdid)

        aq = self.action_queue
        aq.client = DummyClass()
        aq.get_root = get_root
        aq.server_rescan('foo', lambda: list())

        root_requested.addCallback(check)
        return root_requested


class SpecificException(Exception):
    """A dedicated exception type for the error-handling tests."""
    pass


class SillyClass(object):
    """Silly class that accepts the setting of any attribute.

    We can't use object() directly, since it raises AttributeError.

    """


class ErrorHandlingTestCase(BasicTestCase):
    """Error handling tests for ActionQueue."""

    def setUp(self):
        """Give AQ a fake client and a fresh internal deferred, then start."""
        BasicTestCase.setUp(self)

        self.called = False
        aq = self.action_queue
        aq.client = SillyClass()
        aq.deferred = defer.Deferred()
        # never really restart Main while these tests run
        self.patch(self.main, 'restart', lambda: None)

        self.main.start()

    def fail_please(self, an_exception):
        """Build a fake request that flags itself called and fails.

        The returned callable accepts any arguments, sets `self.called`
        and returns a deferred already failed with `an_exception`.

        """
        def inner(*args, **kwargs):
            """A request to the server that fails."""
            self.called = True
            failed = defer.fail(an_exception)
            return failed
        return inner

    def succeed_please(self, result):
        """Build a fake request that flags itself called and succeeds.

        The returned callable accepts any arguments, sets `self.called`
        and returns a deferred already fired with `result`.

        """
        def inner(*args, **kwargs):
            """A request to the server that succeeds."""
            self.called = True
            succeeded = defer.succeed(result)
            return succeeded
        return inner

    def mock_caps(self, accepted):
        """Build a fake query_caps/set_caps that replies with `accepted`.

        The returned callable takes `caps` and returns an already-fired
        deferred whose result has its `caps` and `accepted` attributes set
        to the given `caps` and to `accepted` respectively.

        """
        def gset_caps(caps):
            """get/set caps helper."""
            req = SillyClass()
            req.caps = caps
            req.accepted = accepted
            return defer.succeed(req)
        return gset_caps

    def test_valid_event(self):
        """SYS_SERVER_ERROR is a valid event taking a single 'error' arg."""
        event = 'SYS_SERVER_ERROR'
        # assertIn reports the missing key on failure, unlike assertTrue
        self.assertIn(event, EVENTS)
        self.assertEqual(('error',), EVENTS[event])

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_no_error(self):
        """On success the result is returned and event_ok is pushed."""
        event = 'SYS_SPECIFIC_OK'
        # register the event as valid only for the duration of this test
        EVENTS[event] = ()
        self.addCleanup(lambda: EVENTS.pop(event))

        expected = object()
        request = self.succeed_please(expected)
        actual = yield self.action_queue._send_request_and_handle_errors(
            request=request, request_error=SpecificException,
            event_error='YADDA_YADDA', event_ok=event,
            args=(1, 2), kwargs={})

        self.assertTrue(self.called, 'the request was called')
        self.assertEqual(actual, expected)
        last_event = self.action_queue.event_queue.events[-1]
        self.assertEqual((event, (), {}), last_event)

        # a single log record, naming the request and saying OK
        records = self.handler.records
        self.assertEqual(1, len(records))
        logged = records[0].message
        self.assertIn(request.__name__, logged)
        self.assertIn('OK', logged)

        # on success the AQ internal deferred stays untouched
        self.assertFalse(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_with_no_event_ok(self):
        """No event is pushed on success when event_ok is None."""
        event_queue = self.action_queue.event_queue
        events_before = event_queue.events[:]

        expected = object()
        request = self.succeed_please(expected)
        actual = yield self.action_queue._send_request_and_handle_errors(
            request=request, request_error=SpecificException,
            event_error='YADDA_YADDA', event_ok=None)

        self.assertTrue(self.called, 'the request was called')
        self.assertEqual(actual, expected)
        self.assertEqual(events_before, event_queue.events)

        # a single log record, naming the request and saying OK
        records = self.handler.records
        self.assertEqual(1, len(records))
        logged = records[0].message
        self.assertIn(request.__name__, logged)
        self.assertIn('OK', logged)

        # on success the AQ internal deferred stays untouched
        self.assertFalse(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_valid_error(self):
        """An expected error pushes event_error and fires the deferred."""
        event = 'SYS_SPECIFIC_ERROR'
        # register the event as valid only for the duration of this test
        EVENTS[event] = ('error',)
        self.addCleanup(lambda: EVENTS.pop(event))

        exc = SpecificException('The request failed! please be happy.')
        request = self.fail_please(exc)
        yield self.action_queue._send_request_and_handle_errors(
            request=request, request_error=SpecificException,
            event_error=event, event_ok='YADDA_YADDA')

        self.assertTrue(self.called, 'the request was called')
        expected_event = (event, (), {'error': str(exc)})
        self.assertEqual(expected_event,
                         self.action_queue.event_queue.events[-1])

        # a single log record, naming the request and the error
        records = self.handler.records
        self.assertEqual(1, len(records))
        logged = records[0].message
        self.assertIn(request.__name__, logged)
        self.assertIn(str(exc), logged)

        # the failure is forwarded to the AQ internal deferred
        aq_deferred = self.action_queue.deferred
        self.assertTrue(aq_deferred.called)
        self.assertIsInstance(aq_deferred.result, Failure)
        self.assertEqual(aq_deferred.result.value, exc)

    @defer.inlineCallbacks
    def assert_send_request_and_handle_errors_on_server_error(self, serr):
        """Assert _send_request_and_handle_errors handles server error `serr`.

        `serr` is a protocol_pb2.Error type. The fake request fails with the
        exception that errors.error_to_exception maps `serr` to. That failure
        must be pushed as SYS_SERVER_ERROR and logged. It must also fire the
        AQ internal deferred.

        """
        # XXX: we need to replace this list with an exception list
        # once bug #557718 is resolved
        msg = protocol_pb2.Message()
        msg.type = protocol_pb2.Message.ERROR
        msg.error.type = serr
        msg.error.comment = 'Error message for %s.' % serr
        exc = errors.error_to_exception(serr)(request=None, message=msg)

        request = self.fail_please(exc)
        kwargs = dict(request=request, request_error=SpecificException,
                      event_error='BAR', event_ok='FOO')
        d = self.action_queue._send_request_and_handle_errors(**kwargs)
        yield d

        self.assertEqual(('SYS_SERVER_ERROR', (), {'error': str(exc)}),
                         self.action_queue.event_queue.events[-1])

        # assert over logging: the last record names the request and error
        self.assertTrue(len(self.handler.records) > 0)
        self.assertIn(request.__name__, self.handler.records[-1].message)
        self.assertIn(str(exc), self.handler.records[-1].message)

        # assert internal deferred was fired
        self.assertTrue(self.action_queue.deferred.called)
        self.assertIsInstance(self.action_queue.deferred.result, Failure)
        self.assertEqual(self.action_queue.deferred.result.value, exc)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_try_again(self):
        """TRY_AGAIN is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.TRY_AGAIN)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_internal_error(self):
        """INTERNAL_ERROR is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.INTERNAL_ERROR)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_protocol_error(self):
        """PROTOCOL_ERROR is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.PROTOCOL_ERROR)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_unsupported_version(self):
        """UNSUPPORTED_VERSION is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.UNSUPPORTED_VERSION)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_authetication_failed(self):
        """AUTHENTICATION_FAILED is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.AUTHENTICATION_FAILED)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_no_permission(self):
        """NO_PERMISSION is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.NO_PERMISSION)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_already_exists(self):
        """ALREADY_EXISTS is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.ALREADY_EXISTS)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_does_not_exist(self):
        """DOES_NOT_EXIST is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.DOES_NOT_EXIST)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_not_a_dir(self):
        """NOT_A_DIRECTORY is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.NOT_A_DIRECTORY)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_not_empty(self):
        """NOT_EMPTY is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.NOT_EMPTY)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_not_available(self):
        """NOT_AVAILABLE is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.NOT_AVAILABLE)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_upload_in_porgress(self):
        """UPLOAD_IN_PROGRESS is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.UPLOAD_IN_PROGRESS)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_upload_corrupt(self):
        """UPLOAD_CORRUPT is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.UPLOAD_CORRUPT)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_upload_canceled(self):
        """UPLOAD_CANCELED is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.UPLOAD_CANCELED)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_conflict(self):
        """CONFLICT is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.CONFLICT)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_quota_exceeded(self):
        """QUOTA_EXCEEDED is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.QUOTA_EXCEEDED)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_invalid_filename(self):
        """INVALID_FILENAME is reported as SYS_SERVER_ERROR."""
        yield self.assert_send_request_and_handle_errors_on_server_error(
            protocol_pb2.Error.INVALID_FILENAME)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_unknown_error(self):
        """AUTHENTICATION_REQUIRED is reported as SYS_UNKNOWN_ERROR."""
        # XXX: we need to replace this list with an exception list
        # once bug #557718 is resolved
        serr = protocol_pb2.Error.AUTHENTICATION_REQUIRED
        msg = protocol_pb2.Message()
        msg.type = protocol_pb2.Message.ERROR
        msg.error.type = serr
        msg.error.comment = 'Error message for %s.' % serr
        exc = errors.error_to_exception(serr)(request=None, message=msg)

        request = self.fail_please(exc)
        kwargs = dict(request=request, request_error=SpecificException,
                      event_error='BAR', event_ok='FOO')
        d = self.action_queue._send_request_and_handle_errors(**kwargs)
        yield d

        self.assertIn(('SYS_UNKNOWN_ERROR', (), {}),
                      self.action_queue.event_queue.events)

        # assert over logging: one record naming the request and the error
        self.assertEqual(1, len(self.handler.records))
        self.assertIn(request.__name__, self.handler.records[0].message)
        self.assertIn(str(exc), self.handler.records[0].message)

        # assert internal deferred was fired
        self.assertTrue(self.action_queue.deferred.called)
        self.assertIsInstance(self.action_queue.deferred.result, Failure)
        self.assertEqual(self.action_queue.deferred.result.value, exc)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_no_protocol_error(self):
        """_send_request_and_handle_errors is correct when no-protocol error."""
        exc = Exception('Error message for any Exception.')
        request = self.fail_please(exc)
        yield self.action_queue._send_request_and_handle_errors(
            request=request, request_error=SpecificException,
            event_error='BAR', event_ok='FOO')

        # a non-protocol exception ends up as SYS_UNKNOWN_ERROR
        self.assertIn(('SYS_UNKNOWN_ERROR', (), {}),
                      self.action_queue.event_queue.events)

        # a single log record, naming the request and the error
        records = self.handler.records
        self.assertEqual(1, len(records))
        logged = records[0].message
        self.assertIn(request.__name__, logged)
        self.assertIn(str(exc), logged)

        # the failure is forwarded to the AQ internal deferred
        aq_deferred = self.action_queue.deferred
        self.assertTrue(aq_deferred.called)
        self.assertIsInstance(aq_deferred.result, Failure)
        self.assertEqual(aq_deferred.result.value, exc)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_on_client_mismatch(self):
        """A client swapped mid-request only logs 'Client mismatch'."""

        def change_client(*args, **kwargs):
            """Change AQ's client while doing the request."""
            self.action_queue.client = object()

        # start from an empty event list so nothing pushed goes unnoticed
        self.action_queue.event_queue.events = []
        yield self.action_queue._send_request_and_handle_errors(
            request=change_client, request_error=SpecificException,
            event_error='BAR', event_ok='FOO')

        self.assertEqual([], self.action_queue.event_queue.events)

        # a single log record, naming the request and the mismatch
        records = self.handler.records
        self.assertEqual(1, len(records))
        logged = records[0].message
        self.assertIn(change_client.__name__, logged)
        self.assertIn('Client mismatch', logged)

        # the AQ internal deferred stays untouched
        self.assertFalse(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_send_request_and_handle_errors_when_fire_deferred_false(self):
        """With fire_deferred=False an error leaves the deferred unfired."""
        event = 'SYS_SPECIFIC_ERROR'
        # register the event as valid only for the duration of this test
        EVENTS[event] = ('error',)
        self.addCleanup(lambda: EVENTS.pop(event))

        failing_request = self.fail_please(
            SpecificException('The request failed!'))
        yield self.action_queue._send_request_and_handle_errors(
            request=failing_request, request_error=SpecificException,
            event_error=event, event_ok='YADDA_YADDA',
            fire_deferred=False)

        self.assertFalse(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_check_version_when_unsupported_version_exception(self):
        """check_version pushes SYS_PROTOCOL_VERSION_ERROR on failure."""
        error_msg = protocol_pb2.Message()
        error_msg.type = protocol_pb2.Message.ERROR
        error_msg.error.type = protocol_pb2.Error.UNSUPPORTED_VERSION
        error_msg.error.comment = 'This is a funny comment.'
        # protocol_version will fail with an UnsupportedVersionError
        exc = errors.UnsupportedVersionError(request=None, message=error_msg)
        self.action_queue.client.protocol_version = self.fail_please(exc)

        yield self.action_queue.check_version()

        expected = ('SYS_PROTOCOL_VERSION_ERROR', (), {'error': str(exc)})
        self.assertEqual(expected, self.action_queue.event_queue.events[-1])

    @defer.inlineCallbacks
    def test_set_capabilities_when_query_caps_not_accepted(self):
        """Test error handling when the query caps are not accepted."""

        # query_caps replies with accepted=False
        self.action_queue.client.query_caps = self.mock_caps(accepted=False)

        yield self.action_queue.set_capabilities(caps=None)
        msg = "The server doesn't have the requested capabilities"
        event = ('SYS_SET_CAPABILITIES_ERROR', (), {'error': msg})
        self.assertEqual(event, self.action_queue.event_queue.events[-1])
        self.assertNotIn(('SYS_SET_CAPABILITIES_OK', (), {}),
                         self.action_queue.event_queue.events)

        # assert internal deferred was fired
        self.assertTrue(self.action_queue.deferred.called)
        self.assertIsInstance(self.action_queue.deferred.result, Failure)
        self.assertIsInstance(self.action_queue.deferred.result.value,
                              StandardError)
        self.assertEqual(str(self.action_queue.deferred.result.value), msg)

    @defer.inlineCallbacks
    def test_set_capabilities_when_set_caps_not_accepted(self):
        """Test error handling when the set caps are not accepted."""

        # query_caps replies accepted=True, set_caps replies accepted=False
        self.action_queue.client.query_caps = self.mock_caps(accepted=True)
        self.action_queue.client.set_caps = self.mock_caps(accepted=False)

        caps = 'very difficult cap'
        yield self.action_queue.set_capabilities(caps=caps)
        msg = "The server denied setting '%s' capabilities" % caps
        event = ('SYS_SET_CAPABILITIES_ERROR', (), {'error': msg})
        self.assertEqual(event, self.action_queue.event_queue.events[-1])
        self.assertNotIn(('SYS_SET_CAPABILITIES_OK', (), {}),
                         self.action_queue.event_queue.events)

        # assert internal deferred was fired
        self.assertTrue(self.action_queue.deferred.called)
        self.assertIsInstance(self.action_queue.deferred.result, Failure)
        self.assertIsInstance(self.action_queue.deferred.result.value,
                              StandardError)
        self.assertEqual(str(self.action_queue.deferred.result.value), msg)

    @defer.inlineCallbacks
    def test_set_capabilities_when_client_is_none(self):
        """set_capabilities reports an error event when there's no client."""
        aq = self.action_queue
        aq.client = None

        yield aq.set_capabilities(caps=None)

        msg = "'NoneType' object has no attribute 'query_caps'"
        self.assertEqual(('SYS_SET_CAPABILITIES_ERROR', (), {'error': msg}),
                         aq.event_queue.events[-1])
        self.assertNotIn(('SYS_SET_CAPABILITIES_OK', (), {}),
                         aq.event_queue.events)

        # the error is forwarded to the AQ internal deferred
        aq_deferred = aq.deferred
        self.assertTrue(aq_deferred.called)
        self.assertIsInstance(aq_deferred.result, Failure)
        self.assertIsInstance(aq_deferred.result.value, StandardError)
        self.assertEqual(str(aq_deferred.result.value), msg)

    @defer.inlineCallbacks
    def test_set_capabilities_when_set_caps_is_accepted(self):
        """Test SYS_SET_CAPABILITIES_OK is pushed when caps are accepted."""

        # both query_caps and set_caps reply with accepted=True
        self.action_queue.client.query_caps = self.mock_caps(accepted=True)
        self.action_queue.client.set_caps = self.mock_caps(accepted=True)

        yield self.action_queue.set_capabilities(caps=None)
        event = ('SYS_SET_CAPABILITIES_OK', (), {})
        self.assertEqual(event, self.action_queue.event_queue.events[-1])

        # assert internal deferred wasn't fired
        self.assertFalse(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_authenticate_when_authenticated(self):
        """A successful authenticate pushes SYS_AUTH_OK."""
        aq = self.action_queue
        aq.client.oauth_authenticate = self.succeed_please(result=aq.client)

        yield aq.authenticate(oauth_consumer=object())

        self.assertEqual(('SYS_AUTH_OK', (), {}), aq.event_queue.events[-1])

        # the internal deferred is fired with the client itself as result
        self.assertTrue(aq.deferred.called)
        self.assertTrue(aq.deferred.result is aq.client)

    @defer.inlineCallbacks
    def test_authenticate_when_authentication_failed_exception(self):
        """AuthenticationFailedError is pushed as SYS_AUTH_ERROR."""
        error_msg = protocol_pb2.Message()
        error_msg.type = protocol_pb2.Message.ERROR
        error_msg.error.type = protocol_pb2.Error.AUTHENTICATION_FAILED
        error_msg.error.comment = 'This is a funny comment.'
        # oauth_authenticate will fail with an AuthenticationFailedError
        exc = errors.AuthenticationFailedError(request=None,
                                               message=error_msg)
        self.action_queue.client.oauth_authenticate = self.fail_please(exc)

        yield self.action_queue.authenticate(oauth_consumer=object())

        expected = ('SYS_AUTH_ERROR', (), {'error': str(exc)})
        self.assertEqual(expected, self.action_queue.event_queue.events[-1])

        # the AQ internal deferred must have been fired
        self.assertTrue(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_server_rescan_as_a_whole(self):
        """A clean server_rescan ends with SYS_SERVER_RESCAN_DONE."""

        def faked_get_root(marker):
            """Fake the action_queue.get_root."""
            root_id = object()
            self.action_queue.event_queue.push('SYS_ROOT_RECEIVED',
                                               root_id=root_id)
            return root_id

        self.patch(self.action_queue, 'get_root', faked_get_root)

        aq = self.action_queue
        aq.client.query = self.succeed_please(result=aq.client)
        yield aq.server_rescan(root_mdid=object(), data_gen=list)

        self.assertEqual(('SYS_SERVER_RESCAN_DONE', (), {}),
                         aq.event_queue.events[-1])

        # on success the AQ internal deferred stays untouched
        self.assertFalse(aq.deferred.called)

    @defer.inlineCallbacks
    def test_server_rescan_when_get_root_fails(self):
        """A failing get_root aborts server_rescan with SYS_SERVER_ERROR."""
        error_msg = protocol_pb2.Message()
        error_msg.type = protocol_pb2.Message.ERROR
        error_msg.error.type = protocol_pb2.Error.PROTOCOL_ERROR
        error_msg.error.comment = 'get_root failed'
        exc = errors.StorageRequestError(request=None, message=error_msg)
        self.patch(self.action_queue, 'get_root', self.fail_please(exc))

        # query would fail too, but get_root's failure should come first
        self.action_queue.client.query = self.fail_please(NotImplementedError())

        yield self.action_queue.server_rescan(root_mdid=object(), data_gen=list)

        pushed = self.action_queue.event_queue.events
        self.assertNotIn(('SYS_SERVER_RESCAN_DONE', (), {}), pushed)
        self.assertEqual(('SYS_SERVER_ERROR', (), {'error': str(exc)}),
                         pushed[-1])

        # the AQ internal deferred stays untouched
        self.assertFalse(self.action_queue.deferred.called)

    @defer.inlineCallbacks
    def test_server_rescan_when_query_fails(self):
        """A failing query aborts server_rescan with SYS_SERVER_ERROR."""
        self.patch(self.action_queue, 'get_root',
                   self.succeed_please(result=object()))

        error_msg = protocol_pb2.Message()
        error_msg.type = protocol_pb2.Message.ERROR
        error_msg.error.type = protocol_pb2.Error.PROTOCOL_ERROR
        error_msg.error.comment = 'query failed'
        exc = errors.StorageRequestError(request=None, message=error_msg)
        self.action_queue.client.query = self.fail_please(exc)

        yield self.action_queue.server_rescan(root_mdid=object(), data_gen=list)

        pushed = self.action_queue.event_queue.events
        self.assertNotIn(('SYS_SERVER_RESCAN_DONE', (), {}), pushed)
        self.assertEqual(('SYS_SERVER_ERROR', (), {'error': str(exc)}),
                         pushed[-1])

        # the AQ internal deferred stays untouched
        self.assertFalse(self.action_queue.deferred.called)
