File: //proc/self/root/lib/python3/dist-packages/twisted/conch/test/test_connection.py
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details
"""
This module tests twisted.conch.ssh.connection.
"""
import struct
from twisted.conch.ssh import channel
from twisted.conch.test import test_userauth
from twisted.python.reflect import requireModule
from twisted.trial import unittest
cryptography = requireModule("cryptography")
from twisted.conch import error
if cryptography:
    from twisted.conch.ssh import common, connection
else:
    # The real modules are only imported when cryptography is available.
    # Otherwise define a minimal stand-in exposing SSHConnection so that the
    # TestConnection subclass in this module can still be declared; the test
    # cases themselves set ``skip`` in that situation.
    class connection:  # type: ignore[no-redef]
        class SSHConnection:
            pass
class TestChannel(channel.SSHChannel):
    """
    A stand-in for twisted.conch.ssh.channel.SSHChannel which records every
    event delivered to it so that tests can inspect them afterwards.

    Apart from C{gotOpen} and C{gotClosed}, which have class-level defaults,
    the recording attributes only exist once C{channelOpen} (or, for
    C{openFailureReason}, C{openFailed}) has been called.

    @ivar gotOpen: True if channelOpen has been called.
    @type gotOpen: L{bool}
    @ivar specificData: the specific channel open data passed to channelOpen.
    @type specificData: L{bytes}
    @ivar openFailureReason: the reason passed to openFailed.
    @type openFailureReason: C{error.ConchError}
    @ivar inBuffer: a C{list} of strings received by the channel.
    @type inBuffer: C{list}
    @ivar extBuffer: a C{list} of 2-tuples (type, extended data) received by
        the channel.
    @type extBuffer: C{list}
    @ivar numberRequests: the number of requests that have been made to this
        channel.
    @type numberRequests: L{int}
    @ivar gotEOF: True if the other side sent EOF.
    @type gotEOF: L{bool}
    @ivar gotOneClose: True if the other side closed the connection.
    @type gotOneClose: L{bool}
    @ivar gotClosed: True if the channel is closed.
    @type gotClosed: L{bool}
    """

    name = b"TestChannel"
    gotOpen = False
    gotClosed = False

    def logPrefix(self):
        """
        Return a log prefix which includes this channel's id.
        """
        return "TestChannel %i" % (self.id,)

    def channelOpen(self, specificData):
        """
        The channel has been opened: remember the open data and reset all of
        the recording attributes.
        """
        self.specificData = specificData
        self.gotOpen = True
        # Buffers and counters start out empty.
        self.inBuffer, self.extBuffer = [], []
        self.numberRequests = 0
        # No EOF or close has been seen yet.
        self.gotEOF = self.gotOneClose = self.gotClosed = False

    def openFailed(self, reason):
        """
        The channel could not be opened: remember why.
        """
        self.openFailureReason = reason

    def request_test(self, data):
        """
        Handle a 'test' request.  Count it, and report success only when the
        request payload is exactly C{b"data"}.

        @type data: L{bytes}
        """
        self.numberRequests += 1
        accepted = data == b"data"
        return accepted

    def dataReceived(self, data):
        """
        Record a chunk of received data.
        """
        self.inBuffer.append(data)

    def extReceived(self, code, data):
        """
        Record a chunk of received extended data together with its type code.
        """
        received = (code, data)
        self.extBuffer.append(received)

    def eofReceived(self):
        """
        Note that EOF was received.
        """
        self.gotEOF = True

    def closeReceived(self):
        """
        Note that a close was received.
        """
        self.gotOneClose = True

    def closed(self):
        """
        Note that the channel is now closed.
        """
        self.gotClosed = True
class TestAvatar:
    """
    A minimal stand-in for twisted.conch.avatar.ConchUser.
    """

    # Error code used when deliberately raising a ConchError with its
    # arguments in the wrong order; tests may override it on an instance.
    _ARGS_ERROR_CODE = 123

    def lookupChannel(self, channelType, windowSize, maxPacket, data):
        """
        Return the channel the server asked for.

        A request for C{TestChannel.name} yields a new L{TestChannel} bound to
        this avatar.  A request for C{b"conch-error-args"} raises a
        L{error.ConchError}.  Anything else yields L{None}.
        """
        if channelType == TestChannel.name:
            options = {
                "remoteWindow": windowSize,
                "remoteMaxPacket": maxPacket,
                "data": data,
                "avatar": self,
            }
            return TestChannel(**options)
        if channelType == b"conch-error-args":
            # Raise a ConchError with backwards arguments to make sure the
            # connection fixes it for us.  This case should be deprecated and
            # deleted eventually, but only after all of Conch gets the argument
            # order right.
            raise error.ConchError(self._ARGS_ERROR_CODE, "error args in wrong order")
        return None

    def gotGlobalRequest(self, requestType, data):
        """
        Answer a global request made by the client.

        C{b"TestGlobal"} is answered with C{True}, C{b"TestData"} with a
        C{(True, data)} tuple echoing the request-specific data, and every
        other request type with C{False}.
        """
        if requestType == b"TestGlobal":
            return True
        if requestType != b"TestData":
            return False
        return True, data
class TestConnection(connection.SSHConnection):
    """
    An SSHConnection subclass providing the global request and channel
    handlers exercised by the tests.

    @ivar channel: the most recently created channel.
    @type channel: C{TestChannel}
    """

    if not cryptography:
        skip = "Cannot run without cryptography"

    def logPrefix(self):
        """
        Return the fixed log prefix for this connection.
        """
        return "TestConnection"

    def global_TestGlobal(self, data):
        """
        Handle the 'TestGlobal' global request by reporting success.
        """
        return True

    def global_Test_Data(self, data):
        """
        Handle the 'Test-Data' global request by reporting success and
        echoing back the data which was received.
        """
        reply = (True, data)
        return reply

    def channel_TestChannel(self, windowSize, maxPacket, data):
        """
        Handle a request for the TestChannel: build a C{TestChannel}, keep a
        reference to it in C{self.channel} and hand it back.
        """
        created = TestChannel(
            remoteWindow=windowSize, remoteMaxPacket=maxPacket, data=data
        )
        self.channel = created
        return self.channel

    def channel_ErrorChannel(self, windowSize, maxPacket, data):
        """
        Handle a request for the ErrorChannel by always raising.
        """
        raise AssertionError("no such thing")
class ConnectionTests(unittest.TestCase):
    if not cryptography:
        skip = "Cannot run without cryptography"
    def setUp(self):
        """
        Build a started L{TestConnection} wired to a fake transport whose
        avatar is a L{TestAvatar}.
        """
        fakeTransport = test_userauth.FakeTransport(None)
        fakeTransport.avatar = TestAvatar()
        self.transport = fakeTransport
        connectionUnderTest = TestConnection()
        connectionUnderTest.transport = fakeTransport
        self.conn = connectionUnderTest
        self.conn.serviceStarted()
    def _openChannel(self, channel):
        """
        Open C{channel} on the default connection and immediately confirm it.

        The packet recorded by C{openChannel} is dropped from the transport's
        packet list so that callers only see packets sent afterwards.  The
        confirmation carries the channel's own id followed by 255; the tests
        in this class accordingly expect outgoing packets to be addressed to
        C{0xff}.
        """
        self.conn.openChannel(channel)
        recorded = self.transport.packets
        self.transport.packets = recorded[:-1]
        channelIds = struct.pack(">2L", channel.id, 255)
        # presumably window size and max packet size - TODO confirm
        trailer = b"\x00\x02\x00\x00\x00\x00\x80\x00"
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(channelIds + trailer)
    def tearDown(self):
        """
        Stop the connection service that was started in C{setUp}.
        """
        self.conn.serviceStopped()
    def test_linkAvatar(self):
        """
        Test that the connection links itself to the avatar in the
        transport: after C{setUp} has started the service, the avatar's
        C{conn} attribute is the connection itself.
        """
        self.assertIs(self.transport.avatar.conn, self.conn)
    def test_serviceStopped(self):
        """
        serviceStopped() closes channels which are open and fails the opening
        of channels which were still waiting for confirmation.
        """
        from twisted.internet.error import ConnectionLost

        confirmed = TestChannel()
        unconfirmed = TestChannel()
        for pendingChannel in (confirmed, unconfirmed):
            self.conn.openChannel(pendingChannel)
        # Only the first channel (local id 0) receives a confirmation.
        confirmation = b"\x00\x00\x00\x00" * 4
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(confirmation)
        self.assertTrue(confirmed.gotOpen)
        self.assertFalse(confirmed.gotClosed)
        self.assertFalse(unconfirmed.gotOpen)
        self.assertFalse(unconfirmed.gotClosed)

        self.conn.serviceStopped()

        self.assertTrue(confirmed.gotClosed)
        self.assertFalse(unconfirmed.gotOpen)
        self.assertFalse(unconfirmed.gotClosed)
        self.assertIsInstance(unconfirmed.openFailureReason, ConnectionLost)
    def test_GLOBAL_REQUEST(self):
        """
        Global request packets are dispatched and the handler's return value
        is translated into a success or failure reply; no reply at all is
        sent when the request's reply flag byte is zero.
        """
        replyWanted = b"\xff"
        noReplyWanted = b"\x00"
        sent = lambda: self.transport.packets

        # A handler returning True produces an empty success reply.
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestGlobal") + replyWanted)
        self.assertEqual(sent(), [(connection.MSG_REQUEST_SUCCESS, b"")])

        # A handler returning (True, data) produces a success reply with data.
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(
            common.NS(b"TestData") + replyWanted + b"test data"
        )
        self.assertEqual(sent(), [(connection.MSG_REQUEST_SUCCESS, b"test data")])

        # An unhandled request produces a failure reply.
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestBad") + replyWanted)
        self.assertEqual(sent(), [(connection.MSG_REQUEST_FAILURE, b"")])

        # With the reply flag cleared nothing is sent back.
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b"TestGlobal") + noReplyWanted)
        self.assertEqual(sent(), [])
    def test_REQUEST_SUCCESS(self):
        """
        A global request success packet fires the Deferred returned by
        sendGlobalRequest with the packet's data.
        """
        replyDeferred = self.conn.sendGlobalRequest(b"request", b"data", True)
        self.conn.ssh_REQUEST_SUCCESS(b"data")

        def verifyData(received):
            self.assertEqual(received, b"data")

        replyDeferred.addCallback(verifyData)
        return replyDeferred.addErrback(self.fail)
    def test_REQUEST_FAILURE(self):
        """
        A global request failure packet errbacks the Deferred returned by
        sendGlobalRequest, and the failure carries the packet's data.
        """
        replyDeferred = self.conn.sendGlobalRequest(b"request", b"data", True)
        self.conn.ssh_REQUEST_FAILURE(b"data")

        def verifyFailure(failure):
            self.assertEqual(failure.value.data, b"data")

        replyDeferred.addCallback(self.fail)
        return replyDeferred.addErrback(verifyFailure)
    def test_CHANNEL_OPEN(self):
        """
        Test that open channel packets cause a channel to be created and
        opened or a failure message to be returned.

        Three requests are made, each after clearing the recorded packets:
        one for C{TestChannel} (expected to be confirmed), one for
        C{BadChannel}, for which L{TestConnection} defines no handler, and one
        for C{ErrorChannel}, whose handler raises.
        """
        # Without an avatar the request is served by TestConnection's own
        # channel_* methods, which is what sets self.conn.channel below.
        del self.transport.avatar
        # The channel type is followed by four 32-bit words, each of value 1.
        self.conn.ssh_CHANNEL_OPEN(common.NS(b"TestChannel") + b"\x00\x00\x00\x01" * 4)
        self.assertTrue(self.conn.channel.gotOpen)
        self.assertEqual(self.conn.channel.conn, self.conn)
        self.assertEqual(self.conn.channel.data, b"\x00\x00\x00\x01")
        self.assertEqual(self.conn.channel.specificData, b"\x00\x00\x00\x01")
        self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
        self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_CONFIRMATION,
                    b"\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00"
                    b"\x00\x00\x80\x00",
                )
            ],
        )
        self.transport.packets = []
        # Unknown channel type: a failure with code 3 is expected.
        self.conn.ssh_CHANNEL_OPEN(common.NS(b"BadChannel") + b"\x00\x00\x00\x02" * 4)
        # Discard whatever errors were logged while handling the request.
        self.flushLoggedErrors()
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_FAILURE,
                    b"\x00\x00\x00\x02\x00\x00\x00\x03"
                    + common.NS(b"unknown channel")
                    + common.NS(b""),
                )
            ],
        )
        self.transport.packets = []
        # channel_ErrorChannel raises AssertionError: a failure with code 2
        # and a generic message is expected.
        self.conn.ssh_CHANNEL_OPEN(common.NS(b"ErrorChannel") + b"\x00\x00\x00\x02" * 4)
        self.flushLoggedErrors()
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_FAILURE,
                    b"\x00\x00\x00\x02\x00\x00\x00\x02"
                    + common.NS(b"unknown failure")
                    + common.NS(b""),
                )
            ],
        )
    def _lookupChannelErrorTest(self, code):
        """
        Deliver a request for a channel open which will result in an exception
        being raised during channel lookup.  Assert that an error response is
        delivered as a result.

        @param code: the error code the avatar should raise its
            L{error.ConchError} with.  The same value is expected both in the
            logged error's arguments and in the failure packet which is sent.
        @type code: L{int}
        """
        self.transport.avatar._ARGS_ERROR_CODE = code
        self.conn.ssh_CHANNEL_OPEN(
            common.NS(b"conch-error-args") + b"\x00\x00\x00\x01" * 4
        )
        errors = self.flushLoggedErrors(error.ConchError)
        self.assertEqual(len(errors), 1, f"Expected one error, got: {errors!r}")
        # The avatar raises with (code, message), i.e. in the "wrong" order;
        # the expectation follows the requested code instead of a fixed 123.
        self.assertEqual(errors[0].value.args, (code, "error args in wrong order"))
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_OPEN_FAILURE,
                    # The response includes some bytes which identify the
                    # associated request (channel 1), followed by the error
                    # code as a 32-bit big-endian integer and the error
                    # message.
                    struct.pack(">2L", 1, code)
                    + common.NS(b"error args in wrong order")
                    + common.NS(b""),
                )
            ],
        )
    def test_lookupChannelError(self):
        """
        If a C{lookupChannel} implementation raises L{error.ConchError} with the
        arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
        sent in response to the message.

        This is a temporary work-around until L{error.ConchError} is given
        better attributes and all of the Conch code starts constructing
        instances of it properly.  Eventually this functionality should be
        deprecated and then removed.
        """
        # 123 is the error code the avatar is told to raise with.
        self._lookupChannelErrorTest(123)
    def test_CHANNEL_OPEN_CONFIRMATION(self):
        """
        A channel open confirmation packet tells the channel that it is open
        and makes the connection record the local/remote channel mapping.
        """
        opened = TestChannel()
        self.conn.openChannel(opened)
        # Five zeroed 32-bit words.
        zeroWord = b"\x00\x00\x00\x00"
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(zeroWord * 5)
        self.assertEqual(opened.remoteWindowLeft, 0)
        self.assertEqual(opened.remoteMaxPacket, 0)
        self.assertEqual(opened.specificData, b"\x00\x00\x00\x00")
        remoteByChannel = self.conn.channelsToRemoteChannel
        remoteByLocalId = self.conn.localToRemoteChannel
        self.assertEqual(remoteByChannel[opened], 0)
        self.assertEqual(remoteByLocalId[0], 0)
    def test_CHANNEL_OPEN_FAILURE(self):
        """
        Test that channel open failure packets cause the channel to be
        notified that its opening failed, and that the connection no longer
        tracks the channel afterwards.
        """
        channel = TestChannel()
        self.conn.openChannel(channel)
        # Local channel 0, reason code 1, then the textual description.
        self.conn.ssh_CHANNEL_OPEN_FAILURE(
            b"\x00\x00\x00\x00\x00\x00\x00" b"\x01" + common.NS(b"failure!")
        )
        self.assertEqual(channel.openFailureReason.args, (b"failure!", 1))
        # Looking the channel object itself up as a key can never find
        # anything if the mapping is keyed by channel id, which would make the
        # check vacuous.  Check both the id and the stored values so that the
        # assertion is meaningful however the mapping is keyed.
        self.assertNotIn(channel.id, self.conn.channels)
        self.assertNotIn(channel, self.conn.channels.values())
    def test_CHANNEL_WINDOW_ADJUST(self):
        """
        A channel window adjust message adds the given number of bytes to the
        channel's remote window.
        """
        opened = TestChannel()
        self._openChannel(opened)
        windowBefore = opened.remoteWindowLeft
        # Local channel 0, adjustment of one byte.
        adjustment = b"\x00\x00\x00\x00" + b"\x00\x00\x00\x01"
        self.conn.ssh_CHANNEL_WINDOW_ADJUST(adjustment)
        self.assertEqual(opened.remoteWindowLeft, windowBefore + 1)
    def test_CHANNEL_DATA(self):
        """
        Test that channel data messages are passed up to the channel, or
        cause the channel to be closed if the data is too large.

        Two "too large" cases are covered: data exceeding what is left of the
        local window, and data exceeding the local maximum packet size.
        """
        # Deliberately tiny limits so that they are easy to exceed.
        channel = TestChannel(localWindow=6, localMaxPacket=5)
        self._openChannel(channel)
        # Four bytes addressed to local channel 0 fit within both limits.
        self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x00" + common.NS(b"data"))
        self.assertEqual(channel.inBuffer, [b"data"])
        # A window adjust of 4 is expected to be sent to remote channel 0xff.
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_WINDOW_ADJUST,
                    b"\x00\x00\x00\xff" b"\x00\x00\x00\x04",
                )
            ],
        )
        self.transport.packets = []
        # One byte more than what is left of the local window: nothing is
        # delivered and a close is sent instead.
        longData = b"a" * (channel.localWindowLeft + 1)
        self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x00" + common.NS(longData))
        self.assertEqual(channel.inBuffer, [b"data"])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
        # A second channel with default limits; the data below is addressed to
        # channel 1 (presumably the local id this second channel was given).
        channel = TestChannel()
        self._openChannel(channel)
        # One byte more than the local maximum packet size.
        bigData = b"a" * (channel.localMaxPacket + 1)
        self.transport.packets = []
        self.conn.ssh_CHANNEL_DATA(b"\x00\x00\x00\x01" + common.NS(bigData))
        self.assertEqual(channel.inBuffer, [])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
    def test_CHANNEL_EXTENDED_DATA(self):
        """
        Test that channel extended data messages are passed up to the channel,
        or cause the channel to be closed if they're too big.

        Two "too big" cases are covered: data exceeding what is left of the
        local window, and data exceeding the local maximum packet size.
        """
        # Deliberately tiny limits so that they are easy to exceed.
        channel = TestChannel(localWindow=6, localMaxPacket=5)
        self._openChannel(channel)
        # Local channel 0, extended data type 0, four bytes of data.
        self.conn.ssh_CHANNEL_EXTENDED_DATA(
            b"\x00\x00\x00\x00\x00\x00\x00" b"\x00" + common.NS(b"data")
        )
        self.assertEqual(channel.extBuffer, [(0, b"data")])
        # A window adjust of 4 is expected to be sent to remote channel 0xff.
        self.assertEqual(
            self.transport.packets,
            [
                (
                    connection.MSG_CHANNEL_WINDOW_ADJUST,
                    b"\x00\x00\x00\xff" b"\x00\x00\x00\x04",
                )
            ],
        )
        self.transport.packets = []
        # One byte more than what is left of the local window: nothing is
        # delivered and a close is sent instead.
        longData = b"a" * (channel.localWindowLeft + 1)
        self.conn.ssh_CHANNEL_EXTENDED_DATA(
            b"\x00\x00\x00\x00\x00\x00\x00" b"\x00" + common.NS(longData)
        )
        self.assertEqual(channel.extBuffer, [(0, b"data")])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
        # A second channel with default limits; the data below is addressed to
        # channel 1 (presumably the local id this second channel was given).
        channel = TestChannel()
        self._openChannel(channel)
        # One byte more than the local maximum packet size.
        bigData = b"a" * (channel.localMaxPacket + 1)
        self.transport.packets = []
        self.conn.ssh_CHANNEL_EXTENDED_DATA(
            b"\x00\x00\x00\x01\x00\x00\x00" b"\x00" + common.NS(bigData)
        )
        self.assertEqual(channel.extBuffer, [])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")],
        )
    def test_CHANNEL_EOF(self):
        """
        A channel EOF message is passed on to the channel it addresses.
        """
        opened = TestChannel()
        self._openChannel(opened)
        localChannelZero = b"\x00\x00\x00\x00"
        self.conn.ssh_CHANNEL_EOF(localChannelZero)
        self.assertTrue(opened.gotEOF)
    def test_CHANNEL_CLOSE(self):
        """
        A channel close message is passed on to the channel.  When our side
        has already sent its own close, receiving the message also results in
        the channel's closed() being called.
        """
        opened = TestChannel()
        self._openChannel(opened)
        # Freshly opened: nothing has been closed yet.
        self.assertTrue(opened.gotOpen)
        self.assertFalse(opened.gotOneClose)
        self.assertFalse(opened.gotClosed)

        # Close our side first, then deliver the peer's close.
        self.conn.sendClose(opened)
        localChannelZero = b"\x00\x00\x00\x00"
        self.conn.ssh_CHANNEL_CLOSE(localChannelZero)

        self.assertTrue(opened.gotOneClose)
        self.assertTrue(opened.gotClosed)
    def test_CHANNEL_REQUEST_success(self):
        """
        A channel request which succeeds results in MSG_CHANNEL_SUCCESS being
        sent when a reply was asked for.
        """
        opened = TestChannel()
        self._openChannel(opened)
        # Local channel 0 followed by the request name.
        requestHead = b"\x00\x00\x00\x00" + common.NS(b"test")

        # First without asking for a reply: the request is still delivered.
        self.conn.ssh_CHANNEL_REQUEST(requestHead + b"\x00")
        self.assertEqual(opened.numberRequests, 1)

        # Then asking for a reply, with the payload request_test accepts.
        replyDeferred = self.conn.ssh_CHANNEL_REQUEST(
            requestHead + b"\xff" + b"data"
        )

        def verifyPackets(ignored):
            self.assertEqual(
                self.transport.packets,
                [(connection.MSG_CHANNEL_SUCCESS, b"\x00\x00\x00\xff")],
            )

        return replyDeferred.addCallback(verifyPackets)
    def test_CHANNEL_REQUEST_failure(self):
        """
        Test that channel requests that fail send MSG_CHANNEL_FAILURE.
        """
        channel = TestChannel()
        self._openChannel(channel)
        # The request carries no payload, so TestChannel.request_test (which
        # only accepts b"data") returns False.
        d = self.conn.ssh_CHANNEL_REQUEST(
            b"\x00\x00\x00\x00" + common.NS(b"test") + b"\xff"
        )
        def check(result):
            self.assertEqual(
                self.transport.packets,
                [(connection.MSG_CHANNEL_FAILURE, b"\x00\x00\x00\xff")],
            )
        # NOTE(review): because ``check`` is added as an errback *after*
        # ``self.fail``, it runs both when ``d`` fails and when ``d`` succeeds
        # (``self.fail`` raises, which is then handed to ``check``).  Confirm
        # which outcome of ``d`` is actually intended before restructuring
        # this chain.
        d.addCallback(self.fail)
        d.addErrback(check)
        return d
    def test_CHANNEL_REQUEST_SUCCESS(self):
        """
        Test that channel request success messages cause the Deferred to be
        called back.

        The Deferred is returned to the test runner, so the test fails if it
        ends up with an unhandled failure.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b"test", b"data", True)
        # Deliver the success message for local channel 0.
        self.conn.ssh_CHANNEL_SUCCESS(b"\x00\x00\x00\x00")
        # NOTE(review): ``check`` is defined but never added to ``d``, so its
        # assertion never runs.  Before attaching it, confirm what value the
        # connection fires the Deferred with - ``assertTrue`` would fail on a
        # falsy result.
        def check(result):
            self.assertTrue(result)
        return d
    def test_CHANNEL_REQUEST_FAILURE(self):
        """
        A channel request failure message errbacks the Deferred returned by
        sendRequest.
        """
        opened = TestChannel()
        self._openChannel(opened)
        replyDeferred = self.conn.sendRequest(opened, b"test", b"", True)
        localChannelZero = b"\x00\x00\x00\x00"
        self.conn.ssh_CHANNEL_FAILURE(localChannelZero)

        def verifyFailure(failure):
            self.assertEqual(failure.value.value, "channel request failed")

        replyDeferred.addCallback(self.fail)
        return replyDeferred.addErrback(verifyFailure)
    def test_sendGlobalRequest(self):
        """
        Global request messages are sent in the right format, and only the
        request which wants a reply leaves a pending Deferred behind.
        """

        def ignoreFailure(failure):
            return None

        pending = self.conn.sendGlobalRequest(b"wantReply", b"data", True)
        # must be added to prevent errbacking during teardown
        pending.addErrback(ignoreFailure)
        self.conn.sendGlobalRequest(b"noReply", b"", False)

        withReply = (
            connection.MSG_GLOBAL_REQUEST,
            common.NS(b"wantReply") + b"\xffdata",
        )
        withoutReply = (
            connection.MSG_GLOBAL_REQUEST,
            common.NS(b"noReply") + b"\x00",
        )
        self.assertEqual(self.transport.packets, [withReply, withoutReply])
        self.assertEqual(self.conn.deferreds, {"global": [pending]})
    def test_openChannel(self):
        """
        Open channel messages are sent in the right format, the channel is
        given id 0 and the connection's next local channel id becomes 1.
        """
        opened = TestChannel()
        self.conn.openChannel(opened, b"aaaa")
        expectedPayload = (
            common.NS(b"TestChannel")
            + b"\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa"
        )
        expectedPackets = [(connection.MSG_CHANNEL_OPEN, expectedPayload)]
        self.assertEqual(self.transport.packets, expectedPackets)
        self.assertEqual(opened.id, 0)
        self.assertEqual(self.conn.localChannelID, 1)
    def test_sendRequest(self):
        """
        Channel request messages are sent in the right format, and nothing is
        sent once the channel has been closed locally.
        """
        opened = TestChannel()
        self._openChannel(opened)
        pending = self.conn.sendRequest(opened, b"test", b"test", True)
        # needed to prevent errbacks during teardown.
        pending.addErrback(lambda failure: None)
        self.conn.sendRequest(opened, b"test2", b"", False)
        opened.localClosed = True  # emulate sending a close message
        self.conn.sendRequest(opened, b"test3", b"", True)

        remoteChannel = b"\x00\x00\x00\xff"
        expectedPackets = [
            (
                connection.MSG_CHANNEL_REQUEST,
                remoteChannel + common.NS(b"test") + b"\x01test",
            ),
            (
                connection.MSG_CHANNEL_REQUEST,
                remoteChannel + common.NS(b"test2") + b"\x00",
            ),
        ]
        # The third request, made after the local close, is absent.
        self.assertEqual(self.transport.packets, expectedPackets)
        self.assertEqual(self.conn.deferreds[0], [pending])
    def test_adjustWindow(self):
        """
        adjustWindow adds bytes to the channel's local window and sends a
        window adjust message, but does neither once the channel has been
        closed locally.
        """
        opened = TestChannel(localWindow=5)
        self._openChannel(opened)
        opened.localWindowLeft = 0

        self.conn.adjustWindow(opened, 1)
        self.assertEqual(opened.localWindowLeft, 1)

        # After a local close the adjustment is ignored.
        opened.localClosed = True
        self.conn.adjustWindow(opened, 2)
        self.assertEqual(opened.localWindowLeft, 1)

        onlyAdjustment = (
            connection.MSG_CHANNEL_WINDOW_ADJUST,
            b"\x00\x00\x00\xff" b"\x00\x00\x00\x01",
        )
        self.assertEqual(self.transport.packets, [onlyAdjustment])
    def test_sendData(self):
        """
        Channel data messages are sent in the right format, and nothing is
        sent once the channel has been closed locally.
        """
        opened = TestChannel()
        self._openChannel(opened)
        self.conn.sendData(opened, b"a")
        opened.localClosed = True
        # This second chunk must not show up in the recorded packets.
        self.conn.sendData(opened, b"b")
        expectedPackets = [
            (connection.MSG_CHANNEL_DATA, b"\x00\x00\x00\xff" + common.NS(b"a"))
        ]
        self.assertEqual(self.transport.packets, expectedPackets)
    def test_sendExtendedData(self):
        """
        Channel extended data messages are sent in the right format, and
        nothing is sent once the channel has been closed locally.
        """
        opened = TestChannel()
        self._openChannel(opened)
        self.conn.sendExtendedData(opened, 1, b"test")
        opened.localClosed = True
        # This second chunk must not show up in the recorded packets.
        self.conn.sendExtendedData(opened, 2, b"test2")
        expectedPayload = b"\x00\x00\x00\xff" + b"\x00\x00\x00\x01" + common.NS(b"test")
        expectedPackets = [(connection.MSG_CHANNEL_EXTENDED_DATA, expectedPayload)]
        self.assertEqual(self.transport.packets, expectedPackets)
    def test_sendEOF(self):
        """
        Channel EOF messages are sent in the right format, and no further EOF
        is sent once the channel has been closed locally.
        """
        opened = TestChannel()
        self._openChannel(opened)
        singleEOF = [(connection.MSG_CHANNEL_EOF, b"\x00\x00\x00\xff")]

        self.conn.sendEOF(opened)
        self.assertEqual(self.transport.packets, singleEOF)

        # After a local close the packet list must stay unchanged.
        opened.localClosed = True
        self.conn.sendEOF(opened)
        self.assertEqual(self.transport.packets, singleEOF)
    def test_sendClose(self):
        """
        Channel close messages are sent in the right format and only once per
        channel.  Closing a channel whose remote side is already closed
        results in the channel's closed() being called.
        """
        first = TestChannel()
        self._openChannel(first)
        singleClose = [(connection.MSG_CHANNEL_CLOSE, b"\x00\x00\x00\xff")]

        self.conn.sendClose(first)
        self.assertTrue(first.localClosed)
        self.assertEqual(self.transport.packets, singleClose)

        # A repeated close does not send anything further.
        self.conn.sendClose(first)
        self.assertEqual(self.transport.packets, singleClose)

        second = TestChannel()
        self._openChannel(second)
        self.assertTrue(second.gotOpen)
        self.assertFalse(second.gotClosed)
        # Pretend the peer has already closed its side.
        second.remoteClosed = True
        self.conn.sendClose(second)
        self.assertTrue(second.gotClosed)
    def test_getChannelWithAvatar(self):
        """
        getChannel dispatches to the avatar when an avatar is present.
        Correct functioning without the avatar is verified in
        test_CHANNEL_OPEN.
        """
        found = self.conn.getChannel(b"TestChannel", 50, 30, b"data")
        self.assertEqual(found.data, b"data")
        self.assertEqual(found.remoteWindowLeft, 50)
        self.assertEqual(found.remoteMaxPacket, 30)
        # A channel type the avatar does not provide is an error.
        unknownRequest = (b"BadChannel", 50, 30, b"data")
        self.assertRaises(error.ConchError, self.conn.getChannel, *unknownRequest)
    def test_gotGlobalRequestWithoutAvatar(self):
        """
        Without an avatar, gotGlobalRequest dispatches to the connection's
        global_* methods.
        """
        del self.transport.avatar
        handle = self.conn.gotGlobalRequest
        self.assertTrue(handle(b"TestGlobal", b"data"))
        echoed = handle(b"Test-Data", b"data")
        self.assertEqual(echoed, (True, b"data"))
        self.assertFalse(handle(b"BadGlobal", b"data"))
    def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
        """
        Whenever an SSH channel gets closed any Deferred that was returned by a
        sendRequest() on its parent connection must be errbacked.
        """
        opened = TestChannel()
        self._openChannel(opened)
        pending = self.conn.sendRequest(
            opened, b"dummyrequest", b"dummydata", wantReply=1
        )
        # Register the expectation before triggering the close.
        expectedFailure = self.assertFailure(pending, error.ConchError)
        self.conn.channelClosed(opened)
        return expectedFailure
class CleanConnectionShutdownTests(unittest.TestCase):
    """
    Check whether correct cleanup is performed on connection shutdown.
    """

    if not cryptography:
        skip = "Cannot run without cryptography"

    def setUp(self):
        """
        Build a L{TestConnection} wired to a fake transport with a
        L{TestAvatar}.  The service is left unstarted.
        """
        fakeTransport = test_userauth.FakeTransport(None)
        fakeTransport.avatar = TestAvatar()
        self.transport = fakeTransport
        connectionUnderTest = TestConnection()
        connectionUnderTest.transport = fakeTransport
        self.conn = connectionUnderTest

    def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
        """
        Once the service is stopped any leftover global deferred returned by
        a sendGlobalRequest() call must be errbacked.
        """
        self.conn.serviceStarted()
        pending = self.conn.sendGlobalRequest(
            b"dummyrequest", b"dummydata", wantReply=1
        )
        # Register the expectation before stopping the service.
        expectedFailure = self.assertFailure(pending, error.ConchError)
        self.conn.serviceStopped()
        return expectedFailure