From 0ee675dad1e60c7b10bd99fe7614f5eee676db0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Gait=C3=A1n?= Date: Fri, 13 Jan 2023 12:40:22 -0300 Subject: [PATCH 1/5] support networks as host argument --- pytest_socket.py | 15 ++++++++++++++- tests/test_restrict_hosts.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/pytest_socket.py b/pytest_socket.py index 1f7160b..7eee3b9 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -1,3 +1,4 @@ +import ipaddress import socket import pytest @@ -179,7 +180,7 @@ def socket_allow_hosts(allowed=None, allow_unix_socket=False): def guarded_connect(inst, *args): host = host_from_connect_args(args) - if host in allowed or (_is_unix_socket(inst.family) and allow_unix_socket): + if is_valid_host(host, allowed) or (_is_unix_socket(inst.family) and allow_unix_socket): return _true_connect(inst, *args) raise SocketConnectBlockedError(allowed, host) @@ -191,3 +192,15 @@ def _remove_restrictions(): """restore socket.socket.* to allow access to the Internet. useful in testing.""" socket.socket = _true_socket socket.socket.connect = _true_connect + + +def is_valid_host(host, allowed): + ips = [ip for ip in allowed if "/" not in ip] + if host in ips: + return True + networks = [ipaddress.ip_network(mask) for mask in allowed if "/" in mask] + ip = ipaddress.ip_address(host) + for net in networks: + if ip in net: + return True + return False diff --git a/tests/test_restrict_hosts.py b/tests/test_restrict_hosts.py index e1b721b..5efb258 100644 --- a/tests/test_restrict_hosts.py +++ b/tests/test_restrict_hosts.py @@ -1,4 +1,5 @@ import inspect +from urllib.parse import urlparse import pytest @@ -256,3 +257,33 @@ def test_fail_2(): result.assert_outcomes(1, 0, 2) assert_host_blocked(result, "2.2.2.2") assert_host_blocked(result, httpbin.host) + + +def test_cidr_allow(testdir, httpbin): + test_url = urlparse(httpbin.url) + testdir.makepyfile( + """ + import pytest + import socket + @pytest.mark.allow_hosts('127.0.0.0/8') + def test_pass(): + socket.socket().connect(('{0}', {1})) + @pytest.mark.allow_hosts('127.0.0.0/16') + def test_pass_2(): + socket.socket().connect(('{0}', {1})) + def test_fail(): + socket.socket().connect(('2.2.2.2', {1})) + def test_fail_2(): + socket.socket().connect(('192.168.1.10', {1})) + @pytest.mark.allow_hosts('172.20.0.0/16') + def test_fail_3(): + socket.socket().connect(('{0}', {1})) + """.format( + test_url.hostname, test_url.port + ) + ) + result = testdir.runpytest("--verbose", "--allow-hosts=1.2.3.4") + result.assert_outcomes(2, 0, 3) + assert_host_blocked(result, "2.2.2.2") + assert_host_blocked(result, "192.168.1.10") + assert_host_blocked(result, test_url.hostname) From 8ff4fc226033a52d0845ac2bcd0665ae59f7571a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Gait=C3=A1n?= Date: Fri, 13 Jan 2023 12:44:04 -0300 Subject: [PATCH 2/5] no host shortcut --- pytest_socket.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytest_socket.py b/pytest_socket.py index 7eee3b9..270e02b 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -195,6 +195,8 @@ def _remove_restrictions(): def is_valid_host(host, allowed): + if not host: + return ips = [ip for ip in allowed if "/" not in ip] if host in ips: return True From c49e8fc98a4ba60c08ea4d09430f38a9859e5970 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Gait=C3=A1n?= Date: Fri, 13 Jan 2023 12:52:02 -0300 Subject: [PATCH 3/5] black --- pytest_socket.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytest_socket.py b/pytest_socket.py index 270e02b..2059191 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -180,7 +180,9 @@ def socket_allow_hosts(allowed=None, allow_unix_socket=False): def guarded_connect(inst, *args): host = host_from_connect_args(args) - if is_valid_host(host, allowed) or (_is_unix_socket(inst.family) and allow_unix_socket): + if is_valid_host(host, allowed) or ( + _is_unix_socket(inst.family) and allow_unix_socket + ): return _true_connect(inst, *args) raise SocketConnectBlockedError(allowed, host) From c02ebe06b239b9f9a7ec39345df1e52c73f3cd6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Gait=C3=A1n?= Date: Fri, 13 Jan 2023 13:03:38 -0300 Subject: [PATCH 4/5] more shortcuts --- pytest_socket.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytest_socket.py b/pytest_socket.py index 2059191..a5a15ea 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -198,11 +198,16 @@ def _remove_restrictions(): def is_valid_host(host, allowed): if not host: - return + return False + ips = [ip for ip in allowed if "/" not in ip] if host in ips: return True + networks = [ipaddress.ip_network(mask) for mask in allowed if "/" in mask] + if not networks: + return False + ip = ipaddress.ip_address(host) for net in networks: if ip in net: From 893803353b9ccc775ba112c531e17a94aaf5a9f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Gait=C3=A1n?= Date: Fri, 13 Jan 2023 14:16:07 -0300 Subject: [PATCH 5/5] flake8 --- pytest_socket.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytest_socket.py b/pytest_socket.py index a5a15ea..a5ce505 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -199,15 +199,15 @@ def _remove_restrictions(): def is_valid_host(host, allowed): if not host: return False - + ips = [ip for ip in allowed if "/" not in ip] if host in ips: return True networks = [ipaddress.ip_network(mask) for mask in allowed if "/" in mask] - if not networks: + if not networks: return False - + ip = ipaddress.ip_address(host) for net in networks: if ip in net: