Skip to content

PYTHON-4441 Use deferred imports instead of lazy module loading #1648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
address review
  • Loading branch information
blink1073 committed May 30, 2024
commit c267283f6aa2ada5acee02f2354208efe19add2a
12 changes: 5 additions & 7 deletions pymongo/auth_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""MONGODB-AWS Authentication helpers."""
from __future__ import annotations

import importlib.util
from typing import TYPE_CHECKING, Any, Mapping, Type

import bson
Expand All @@ -27,19 +26,18 @@
from pymongo.auth import MongoCredential
from pymongo.pool import Connection

_HAVE_MONGODB_AWS = importlib.util.find_spec("pymongo_auth_aws") is not None


def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-AWS."""
if not _HAVE_MONGODB_AWS:
try:
import pymongo_auth_aws # type:ignore[import]
except ImportError:
raise ConfigurationError(
"MONGODB-AWS authentication requires pymongo-auth-aws: "
"install with: python -m pip install 'pymongo[aws]'"
)
) from None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the from None so the error shows the ImportError.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


# Delayed imports.
import pymongo_auth_aws # type:ignore[import]
# Delayed import.
from pymongo_auth_aws.auth import ( # type:ignore[import]
set_cached_credentials,
set_use_cached_credentials,
Expand Down
38 changes: 30 additions & 8 deletions pymongo/compression_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,44 @@
# limitations under the License.
from __future__ import annotations

import importlib.util
import warnings
from typing import Any, Iterable, Optional, Union

from pymongo.hello import HelloCompat
from pymongo.helpers import _SENSITIVE_COMMANDS

_HAVE_SNAPPY = importlib.util.find_spec("snappy") is not None
_HAVE_ZLIB = importlib.util.find_spec("zlib") is not None
_HAVE_ZSTD = importlib.util.find_spec("zstandard") is not None

_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)


def _have_snappy() -> bool:
try:
import snappy # type:ignore[import] # noqa: F401

return True
except ImportError:
return False


def _have_zlib() -> bool:
try:
import zlib # noqa: F401

return True
except ImportError:
return False


def _have_zstd() -> bool:
try:
import zstandard # type:ignore[import] # noqa: F401

return True
except ImportError:
return False


def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]:
try:
# `value` is string.
Expand All @@ -41,21 +63,21 @@ def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[s
if compressor not in _SUPPORTED_COMPRESSORS:
compressors.remove(compressor)
warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2)
elif compressor == "snappy" and not _HAVE_SNAPPY:
elif compressor == "snappy" and not _have_snappy():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with snappy is not available. "
"You must install the python-snappy module for snappy support.",
stacklevel=2,
)
elif compressor == "zlib" and not _HAVE_ZLIB:
elif compressor == "zlib" and not _have_zlib():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zlib is not available. "
"The zlib module is not available.",
stacklevel=2,
)
elif compressor == "zstd" and not _HAVE_ZSTD:
elif compressor == "zstd" and not _have_zstd():
compressors.remove(compressor)
warnings.warn(
"Wire protocol compression with zstandard is not available. "
Expand Down
6 changes: 2 additions & 4 deletions pymongo/pyopenssl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from ipaddress import ip_address as _ip_address
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import cryptography.x509 as x509 # type:ignore[import]
from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto

Expand All @@ -39,7 +40,6 @@
if TYPE_CHECKING:
from ssl import VerifyMode

from cryptography.x509 import Certificate

_T = TypeVar("_T")

Expand Down Expand Up @@ -179,7 +179,7 @@ class _CallbackData:
"""Data class which is passed to the OCSP callback."""

def __init__(self) -> None:
self.trusted_ca_certs: Optional[list[Certificate]] = None
self.trusted_ca_certs: Optional[list[x509.Certificate]] = None
self.check_ocsp_endpoint: Optional[bool] = None
self.ocsp_response_cache = _OCSPCache()

Expand Down Expand Up @@ -331,7 +331,6 @@ def _load_wincerts(self, store: str) -> None:
"""Attempt to load CA certs from Windows trust store."""
cert_store = self._ctx.get_cert_store()
oid = _stdlibssl.Purpose.SERVER_AUTH.oid
import cryptography.x509 as x509

for cert, encoding, trust in _stdlibssl.enum_certificates(store): # type: ignore
if encoding == "x509_asn":
Expand Down Expand Up @@ -401,7 +400,6 @@ def wrap_socket(
# XXX: Do this in a callback registered with
# SSLContext.set_info_callback? See Twisted for an example.
if self.check_hostname and server_hostname is not None:
import service_identity
import service_identity.pyopenssl

try:
Expand Down
10 changes: 8 additions & 2 deletions pymongo/srv_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Support for resolving hosts and options from mongodb+srv:// URIs."""
from __future__ import annotations

import importlib.util
import ipaddress
import random
from typing import TYPE_CHECKING, Any, Optional, Union
Expand All @@ -26,7 +25,14 @@
if TYPE_CHECKING:
from dns import resolver

_HAVE_DNSPYTHON = importlib.util.find_spec("dns") is not None

def _have_dnspython() -> bool:
try:
import dns # type:ignore[import] # noqa: F401

return True
except ImportError:
return False


# dnspython can return bytes or str from various parts
Expand Down
4 changes: 2 additions & 2 deletions pymongo/uri_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
get_validated_options,
)
from pymongo.errors import ConfigurationError, InvalidURI
from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
from pymongo.srv_resolver import _have_dnspython, _SrvResolver
from pymongo.typings import _Address

if TYPE_CHECKING:
Expand Down Expand Up @@ -472,7 +472,7 @@ def parse_uri(
is_srv = False
scheme_free = uri[SCHEME_LEN:]
elif uri.startswith(SRV_SCHEME):
if not _HAVE_DNSPYTHON:
if not _have_dnspython():
python_path = sys.executable or "python"
raise ConfigurationError(
'The "dnspython" module must be '
Expand Down
6 changes: 3 additions & 3 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
from pymongo.client_options import ClientOptions
from pymongo.command_cursor import CommandCursor
from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
from pymongo.compression_support import _HAVE_SNAPPY, _HAVE_ZSTD
from pymongo.compression_support import _have_snappy, _have_zstd
from pymongo.cursor import Cursor, CursorType
from pymongo.database import Database
from pymongo.driver_info import DriverInfo
Expand Down Expand Up @@ -1558,7 +1558,7 @@ def compression_settings(client):
self.assertEqual(opts.compressors, ["zlib"])
self.assertEqual(opts.zlib_compression_level, -1)

if not _HAVE_SNAPPY:
if not _have_snappy():
uri = "mongodb://localhost:27017/?compressors=snappy"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
Expand All @@ -1573,7 +1573,7 @@ def compression_settings(client):
opts = compression_settings(client)
self.assertEqual(opts.compressors, ["snappy", "zlib"])

if not _HAVE_ZSTD:
if not _have_zstd():
uri = "mongodb://localhost:27017/?compressors=zstd"
client = MongoClient(uri, connect=False)
opts = compression_settings(client)
Expand Down
4 changes: 2 additions & 2 deletions test/test_srv_polling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pymongo import common
from pymongo.errors import ConfigurationError
from pymongo.mongo_client import MongoClient
from pymongo.srv_resolver import _HAVE_DNSPYTHON
from pymongo.srv_resolver import _have_dnspython

WAIT_TIME = 0.1

Expand Down Expand Up @@ -148,7 +148,7 @@ def predicate():
return True

def run_scenario(self, dns_response, expect_change):
self.assertEqual(_HAVE_DNSPYTHON, True)
self.assertEqual(_have_dnspython(), True)
if callable(dns_response):
dns_resolver_response = dns_response
else:
Expand Down
4 changes: 2 additions & 2 deletions test/test_uri_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from test import clear_warning_registry, unittest

from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
from pymongo.compression_support import _HAVE_SNAPPY
from pymongo.compression_support import _have_snappy
from pymongo.uri_parser import SRV_SCHEME, parse_uri

CONN_STRING_TEST_PATH = os.path.join(
Expand Down Expand Up @@ -95,7 +95,7 @@ def modified_test_scenario(*args, **kwargs):
def create_test(test, test_workdir):
def run_scenario(self):
compressors = (test.get("options") or {}).get("compressors", [])
if "snappy" in compressors and not _HAVE_SNAPPY:
if "snappy" in compressors and not _have_snappy:
self.skipTest("This test needs the snappy module.")
valid = True
warning = False
Expand Down
Loading