Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
run: uv python install ${{ matrix.python-version }}
- name: Prepare project for development
run: uv sync --extra dev
- name: Check typing
run: uv run mypy src
- name: Test with pytest
run: |
uv run coverage run -m pytest
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dev = [
"requests >= 2.32.4",
"starlette >= 0.47.1",
"httpx >= 0.28.1",
"mypy >= 1.19.1",
]

[project.urls]
Expand All @@ -54,6 +55,9 @@ known_first_party = ['pytest_socket', 'conftest']
# https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#profilegcm
profile = "black"

[tool.mypy]
strict = true

[tool.vulture]
ignore_decorators = ["@pytest.fixture"]
ignore_names = ["pytest_*"]
Expand Down
78 changes: 49 additions & 29 deletions src/pytest_socket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,37 @@
import socket
import typing
from collections import defaultdict
from collections.abc import Iterator
from dataclasses import dataclass, field

import pytest
import typing_extensions

_true_socket = socket.socket
_true_connect = socket.socket.connect


class SocketBlockedError(RuntimeError):
def __init__(self, *_args, **_kwargs):
def __init__(self, *_args: typing.Any, **_kwargs: typing.Any):
super().__init__("A test tried to use socket.socket.")


class SocketConnectBlockedError(RuntimeError):
def __init__(self, allowed, host, *_args, **_kwargs):
if allowed:
allowed = ",".join(allowed)
def __init__(
self,
allowed: typing.List[str],
host: typing.Union[str, None],
*_args: typing.Any,
**_kwargs: typing.Any,
):
allowed_str = ",".join(allowed)
super().__init__(
"A test tried to use socket.socket.connect() "
f'with host "{host}" (allowed: "{allowed}").'
f'with host "{host}" (allowed: "{allowed_str}").'
)


def pytest_addoption(parser):
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("socket")
group.addoption(
"--disable-socket",
Expand Down Expand Up @@ -55,15 +62,15 @@ def pytest_addoption(parser):


@pytest.fixture
def socket_disabled(pytestconfig):
def socket_disabled(pytestconfig: pytest.Config) -> Iterator[None]:
"""disable socket.socket for duration of this test function"""
socket_config = pytestconfig.stash[_STASH_KEY]
disable_socket(allow_unix_socket=socket_config.allow_unix_socket)
yield


@pytest.fixture
def socket_enabled(pytestconfig):
def socket_enabledpy(pytestconfig: pytest.Config) -> Iterator[None]:
"""enable socket.socket for duration of this test function"""
enable_socket()
yield
Expand All @@ -81,7 +88,7 @@ class _PytestSocketConfig:
_STASH_KEY = pytest.StashKey[_PytestSocketConfig]()


def _is_unix_socket(family) -> bool:
def _is_unix_socket(family: int) -> bool:
try:
is_unix_socket = family == socket.AF_UNIX
except AttributeError:
Expand All @@ -90,27 +97,35 @@ def _is_unix_socket(family) -> bool:
return is_unix_socket


def disable_socket(allow_unix_socket=False):
def disable_socket(allow_unix_socket: bool = False) -> None:
"""disable socket.socket to disable the Internet. useful in testing."""

class GuardedSocket(socket.socket):
"""socket guard to disable socket creation (from pytest-socket)"""

def __new__(cls, family=-1, type=-1, proto=-1, fileno=None):
def __new__(
cls,
family: typing.Union[socket.AddressFamily, int] = -1,
type: typing.Union[socket.SocketKind, int] = -1,
proto: int = -1,
fileno: typing.Union[int, None] = None,
) -> typing_extensions.Self:
if _is_unix_socket(family) and allow_unix_socket:
return super().__new__(cls, family, type, proto, fileno)
return super().__new__(
cls, family, type, proto, fileno # type: ignore[call-arg]
)

raise SocketBlockedError()

socket.socket = GuardedSocket
socket.socket = GuardedSocket # type: ignore[misc]


def enable_socket():
def enable_socket() -> None:
"""re-enable socket.socket to enable the Internet. useful in testing."""
socket.socket = _true_socket
socket.socket = _true_socket # type: ignore[misc]


def pytest_configure(config):
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line(
"markers", "disable_socket(): Disable socket connections for a specific test"
)
Expand All @@ -131,7 +146,7 @@ def pytest_configure(config):
)


def pytest_runtest_setup(item) -> None:
def pytest_runtest_setup(item: pytest.Item) -> None:
"""During each test item's setup phase,
choose the behavior based on the configurations supplied.

Expand Down Expand Up @@ -172,7 +187,9 @@ def pytest_runtest_setup(item) -> None:
disable_socket(socket_config.allow_unix_socket)


def _resolve_allow_hosts(item):
def _resolve_allow_hosts(
item: pytest.Item,
) -> typing.Union[str, typing.List[str], None]:
"""Resolve `allow_hosts` behaviors."""
socket_config = item.config.stash[_STASH_KEY]

Expand All @@ -192,21 +209,23 @@ def _resolve_allow_hosts(item):
return hosts


def pytest_runtest_teardown():
def pytest_runtest_teardown() -> None:
_remove_restrictions()


def host_from_address(address):
def host_from_address(address: tuple[typing.Any, ...]) -> typing.Union[str, None]:
host = address[0]
if isinstance(host, str):
return host
return None


def host_from_connect_args(args):
def host_from_connect_args(args: tuple[typing.Any, ...]) -> typing.Union[str, None]:
address = args[0]

if isinstance(address, tuple):
return host_from_address(address)
return None


def is_ipaddress(address: str) -> bool:
Expand All @@ -223,15 +242,16 @@ def is_ipaddress(address: str) -> bool:
def resolve_hostnames(hostname: str) -> typing.Set[str]:
try:
return {
addr_struct[0] for *_, addr_struct in socket.getaddrinfo(hostname, None)
addr_struct[0] # type: ignore[misc]
for *_, addr_struct in socket.getaddrinfo(hostname, None)
}
except socket.gaierror:
return set()


def normalize_allowed_hosts(
allowed_hosts: typing.List[str],
resolution_cache: typing.Optional[typing.Dict[str, typing.List[str]]] = None,
resolution_cache: typing.Optional[typing.Dict[str, typing.Set[str]]] = None,
) -> typing.Dict[str, typing.Set[str]]:
"""Map all items in `allowed_hosts` to IP addresses."""
if resolution_cache is None:
Expand All @@ -252,7 +272,7 @@ def normalize_allowed_hosts(
def socket_allow_hosts(
allowed: typing.Union[str, typing.List[str], None] = None,
allow_unix_socket: bool = False,
resolution_cache: typing.Optional[typing.Dict[str, typing.List[str]]] = None,
resolution_cache: typing.Optional[typing.Dict[str, typing.Set[str]]] = None,
) -> None:
"""disable socket.socket.connect() to disable the Internet. useful in testing."""
if isinstance(allowed, str):
Expand All @@ -276,7 +296,7 @@ def socket_allow_hosts(
]
)

def guarded_connect(inst, *args):
def guarded_connect(inst: socket.socket, *args: typing.Any) -> None:
host = host_from_connect_args(args)
if host in allowed_ip_hosts_and_hostnames or (
_is_unix_socket(inst.family) and allow_unix_socket
Expand All @@ -285,10 +305,10 @@ def guarded_connect(inst, *args):

raise SocketConnectBlockedError(allowed_list, host)

socket.socket.connect = guarded_connect
socket.socket.connect = guarded_connect # type: ignore[assignment,method-assign]


def _remove_restrictions():
def _remove_restrictions() -> None:
"""restore socket.socket.* to allow access to the Internet. useful in testing."""
socket.socket = _true_socket
socket.socket.connect = _true_connect
socket.socket = _true_socket # type: ignore[misc]
socket.socket.connect = _true_connect # type: ignore[method-assign]
Empty file added src/pytest_socket/py.typed
Empty file.
Loading
Loading