From ee08d8ebcd8d9ab4537f967c68d5fe0d41c409d8 Mon Sep 17 00:00:00 2001 From: jguer Date: Fri, 5 Nov 2021 16:06:16 +0100 Subject: [PATCH] feat(allow_hosts): add support for hostname resolution --- pytest_socket.py | 36 +++++++++++++++++++++++++++++++++++- tests/test_restrict_hosts.py | 8 ++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/pytest_socket.py b/pytest_socket.py index 5e4ca3c..7c53480 100644 --- a/pytest_socket.py +++ b/pytest_socket.py @@ -139,17 +139,51 @@ def host_from_connect_args(args): return host_from_address(address) +def is_ipaddress(address: str): + """ + Determine if the address is a valid IPv4 address. + """ + try: + socket.inet_aton(address) + return True + except socket.error: + return False + + +def resolve_hostname(hostname): + try: + return socket.gethostbyname(hostname) + except socket.gaierror: + return None + + +def treat_allowed(allowed): + allowed_hosts = [] + for allow in allowed: + allow = allow.strip() + if is_ipaddress(allow): + allowed_hosts.append(allow) + + resolved = resolve_hostname(allow) + if resolved: + allowed_hosts.append(resolved) + return allowed_hosts + + def socket_allow_hosts(allowed=None): """ disable socket.socket.connect() to disable the Internet. useful in testing. """ if isinstance(allowed, str): allowed = allowed.split(',') + if not isinstance(allowed, list): return + allowed_hosts = treat_allowed(allowed) + def guarded_connect(inst, *args): host = host_from_connect_args(args) - if host and host in allowed: + if host and host in allowed_hosts: return _true_connect(inst, *args) raise SocketConnectBlockedError(allowed, host) diff --git a/tests/test_restrict_hosts.py b/tests/test_restrict_hosts.py index 76c6bf2..fde1af1 100644 --- a/tests/test_restrict_hosts.py +++ b/tests/test_restrict_hosts.py @@ -106,6 +106,14 @@ def test_single_cli_arg_connect_enabled(assert_connect): assert_connect(True, cli_arg=localhost) +def test_single_cli_arg_connect_enabled_hostname_resolved(assert_connect): + assert_connect(True, cli_arg="localhost") + + +def test_single_cli_arg_connect_enabled_hostname_unresolvable(assert_connect): + assert_connect(False, cli_arg="unresolvable") + + def test_single_cli_arg_connect_unicode_enabled(assert_connect): assert_connect(True, cli_arg=localhost, code_template=connect_unicode_code_template)