File: //proc/thread-self/root/usr/local/lib/python3.8/dist-packages/trio/tests/test_highlevel_socket.py
import pytest
import sys
import socket as stdlib_socket
import errno
from .. import _core
from ..testing import (
    check_half_closeable_stream,
    wait_all_tasks_blocked,
    assert_checkpoints,
)
from .._highlevel_socket import *
from .. import socket as tsocket
async def test_SocketStream_basics():
    """SocketStream validates the socket it wraps and proxies socket options.

    Plain stdlib sockets are rejected with TypeError, datagram sockets with
    ValueError, and a Trio stream socket is exposed unchanged as ``.socket``.
    On a real TCP connection, TCP_NODELAY starts enabled and can be toggled
    through the stream's getsockopt/setsockopt.
    """
    # A stdlib socket is rejected, even when it is connected.
    plain_left, plain_right = stdlib_socket.socketpair()
    with plain_left, plain_right:
        with pytest.raises(TypeError):
            SocketStream(plain_left)

    # A datagram socket is rejected too.
    with tsocket.socket(type=tsocket.SOCK_DGRAM) as dgram_sock:
        with pytest.raises(ValueError):
            SocketStream(dgram_sock)

    # A Trio stream socket is accepted and exposed as-is.
    left, right = tsocket.socketpair()
    with left, right:
        wrapped = SocketStream(left)
        assert wrapped.socket is left

    # socketpair() may hand back a unix-domain socket that doesn't support
    # the TCP options below, so use a real connected TCP socket instead.
    with tsocket.socket() as listen_sock:
        await listen_sock.bind(("127.0.0.1", 0))
        listen_sock.listen(1)
        with tsocket.socket() as client_sock:
            await client_sock.connect(listen_sock.getsockname())
            stream = SocketStream(client_sock)
            level, option = tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY

            # TCP_NODELAY is on by default...
            assert stream.getsockopt(level, option)
            # ...but it can be switched off.
            stream.setsockopt(level, option, False)
            assert not stream.getsockopt(level, option)

            # With a buffer length argument the raw option bytes come back.
            raw_value = stream.getsockopt(level, option, 1)
            assert isinstance(raw_value, bytes)
async def test_SocketStream_send_all():
    """A large send_all() arrives intact and leaves the caller's buffer usable.

    BIG NUL bytes are pushed through a socketpair while a concurrent receiver
    drains them. Afterwards a one-byte message and an EOF are sent to confirm
    that nothing beyond the BIG bytes was in flight.
    """
    BIG = 10000000
    tx_sock, rx_sock = tsocket.socketpair()
    with tx_sock, rx_sock:
        tx = SocketStream(tx_sock)
        rx = SocketStream(rx_sock)

        # On most platforms this send_all has to be split into several
        # partial sends. On Windows every send() succeeds or fails as a whole.
        async def sender():
            payload = bytearray(BIG)
            await tx.send_all(payload)
            # send_all uses memoryviews internally, and a memoryview that is
            # not released "locks" the bytearray it views. Growing the
            # bytearray forces a realloc, which would raise if a view had
            # leaked.
            #
            # Caveat: this is a weak check. CPython's refcounting usually
            # frees stray memoryviews anyway, and PyPy3 (as of 5.7.0) updates
            # memoryviews when the buffer moves, so resizing never fails
            # there. See:
            # https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
            # It is kept as an early warning in case either of those changes
            # and send_all regresses.
            payload += bytes(BIG)

        async def receiver():
            # Wait until the sender has filled the kernel buffers and blocked.
            await wait_all_tasks_blocked()
            received = 0
            while received < BIG:
                chunk = await rx.receive_some(BIG)
                received += len(chunk)
            assert received == BIG

        async with _core.open_nursery() as nursery:
            nursery.start_soon(sender)
            nursery.start_soon(receiver)

        # Exactly BIG NUL bytes were received. Check that no stray data is
        # queued behind them.
        await tx.send_all(b"e")
        assert await rx.receive_some(10) == b"e"
        await tx.send_eof()
        assert await rx.receive_some(10) == b""
async def fill_stream(s):
    """Write to stream ``s`` until the writes block, then stop.

    One task calls ``s.send_all`` with 10000-byte chunks forever. A second
    task waits until every task is blocked, presumably because the sender is
    stuck on full buffers, and then cancels the nursery, which ends both.
    """

    async def keep_sending():
        chunk = b"x" * 10000
        while True:
            await s.send_all(chunk)

    async def cancel_when_idle(owner):
        await wait_all_tasks_blocked()
        owner.cancel_scope.cancel()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(keep_sending)
        nursery.start_soon(cancel_when_idle, nursery)
async def test_SocketStream_generic():
    """Run the generic half-closeable stream checks against SocketStream."""

    async def make_pair():
        sock_a, sock_b = tsocket.socketpair()
        return SocketStream(sock_a), SocketStream(sock_b)

    async def make_clogged_pair():
        # Like make_pair(), but each stream has been written to until its
        # sends block.
        streams = await make_pair()
        for stream in streams:
            await fill_stream(stream)
        return streams

    await check_half_closeable_stream(make_pair, make_clogged_pair)
async def test_SocketListener():
    """Check SocketListener's constructor validation and its accept/aclose behavior.

    The constructor must reject:

    - non-Trio sockets, with TypeError;
    - non-SOCK_STREAM sockets, with a ValueError mentioning SOCK_STREAM;
    - where the platform can detect it, sockets on which ``.listen()`` was
      never called, with a ValueError mentioning listen.

    A working listener must:

    - return SocketStreams from ``accept()``;
    - checkpoint in both ``accept()`` and ``aclose()``;
    - tolerate a repeated ``aclose()``;
    - raise ClosedResourceError from ``accept()`` once closed.
    """
    # Not a Trio socket
    with stdlib_socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        s.listen(10)
        with pytest.raises(TypeError):
            SocketListener(s)

    # Not a SOCK_STREAM
    with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
        await s.bind(("127.0.0.1", 0))
        with pytest.raises(ValueError) as excinfo:
            SocketListener(s)
        excinfo.match(r".*SOCK_STREAM")

    # Didn't call .listen()
    # macOS has no way to check for this, so skip testing it there.
    if sys.platform != "darwin":
        with tsocket.socket() as s:
            await s.bind(("127.0.0.1", 0))
            with pytest.raises(ValueError) as excinfo:
                SocketListener(s)
            excinfo.match(r".*listen")

    # The sockets below are managed with ``with``/``finally`` so that a failed
    # assertion doesn't leak file descriptors (and ResourceWarnings) into
    # later tests. Closing an already-closed socket is harmless, as the
    # repeated listener.aclose() below demonstrates, so the ``with`` exit
    # doesn't conflict with the explicit closes being tested.
    with tsocket.socket() as listen_sock:
        await listen_sock.bind(("127.0.0.1", 0))
        listen_sock.listen(10)
        listener = SocketListener(listen_sock)
        assert listener.socket is listen_sock

        with tsocket.socket() as client_sock:
            await client_sock.connect(listen_sock.getsockname())
            with assert_checkpoints():
                server_stream = await listener.accept()
            try:
                assert isinstance(server_stream, SocketStream)
                assert server_stream.socket.getsockname() == listen_sock.getsockname()
                assert server_stream.socket.getpeername() == client_sock.getsockname()

                with assert_checkpoints():
                    await listener.aclose()
                # A second aclose() is allowed and still checkpoints.
                with assert_checkpoints():
                    await listener.aclose()

                with assert_checkpoints():
                    with pytest.raises(_core.ClosedResourceError):
                        await listener.accept()
            finally:
                await server_stream.aclose()
async def test_SocketListener_socket_closed_underfoot():
    """accept() raises ClosedResourceError if the raw socket was closed directly."""
    raw_sock = tsocket.socket()
    await raw_sock.bind(("127.0.0.1", 0))
    raw_sock.listen(10)
    listener = SocketListener(raw_sock)

    # Close the underlying socket behind the listener's back, without going
    # through listener.aclose().
    raw_sock.close()

    with assert_checkpoints(), pytest.raises(_core.ClosedResourceError):
        await listener.accept()
async def test_SocketListener_accept_errors():
    """SocketListener.accept() skips some errnos and re-raises others.

    A fake listening socket replays a scripted sequence:

    1. Three errors that the listener must swallow, so the first accept()
       still returns the fake connection.
    2. Three errors that must propagate to the caller.
    3. One more connection, showing the listener is still usable afterwards.
    """

    class FakeSocket(tsocket.SocketType):
        def __init__(self, events):
            self._events = iter(events)

        type = tsocket.SOCK_STREAM

        # Fool the check for SO_ACCEPTCONN in SocketListener.__init__
        def getsockopt(self, level, opt):
            return True

        # No-op. Presumably reached when the accepted socket is wrapped in a
        # SocketStream, which sets options on it. TODO confirm.
        def setsockopt(self, level, opt, value):
            pass

        async def accept(self):
            await _core.checkpoint()
            step = next(self._events)
            if not isinstance(step, BaseException):
                return step, None
            raise step

    accepted_sock = FakeSocket([])
    ignored_errors = [
        OSError(errno.ECONNABORTED, "Connection aborted"),
        OSError(errno.EPERM, "Permission denied"),
        OSError(errno.EPROTO, "Bad protocol"),
    ]
    fatal_errors = [
        OSError(errno.EMFILE, "Out of file descriptors"),
        OSError(errno.EFAULT, "attempt to write to read-only memory"),
        OSError(errno.ENOBUFS, "out of buffers"),
    ]
    listening_sock = FakeSocket(
        ignored_errors + [accepted_sock] + fatal_errors + [accepted_sock]
    )
    listener = SocketListener(listening_sock)

    # The ignorable errors are consumed silently inside this accept().
    with assert_checkpoints():
        stream = await listener.accept()
        assert stream.socket is accepted_sock

    for err in fatal_errors:
        with assert_checkpoints():
            with pytest.raises(OSError) as excinfo:
                await listener.accept()
            assert excinfo.value.errno == err.errno

    with assert_checkpoints():
        stream = await listener.accept()
        assert stream.socket is accepted_sock
async def test_socket_stream_works_when_peer_has_already_closed():
    """Data sent before the peer closed is still readable, followed by EOF."""
    local, remote = tsocket.socketpair()
    with local, remote:
        await remote.send(b"x")
        # Close the peer before the surviving end is wrapped in a stream.
        remote.close()
        stream = SocketStream(local)
        for expected in (b"x", b""):
            assert await stream.receive_some(1) == expected