Skip to content

Commit 2a0717b

Browse files
authored
FIX-modin-project#4048: support sqlalchemy objects in con parameter for to_sql (modin-project#5940)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent e1d4241 commit 2a0717b

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

modin/core/io/sql/sql_dispatcher.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ def _read(cls, sql, con, index_col=None, **kwargs):
110110
new_frame.synchronize_labels(axis=0)
111111
return cls.query_compiler_cls(new_frame)
112112

113+
@classmethod
114+
def _is_supported_sqlalchemy_object(cls, obj): # noqa: GL08
115+
supported = None
116+
try:
117+
import sqlalchemy as sa
118+
119+
supported = isinstance(obj, (sa.engine.Engine, sa.engine.Connection))
120+
except ImportError:
121+
supported = False
122+
return supported
123+
113124
@classmethod
114125
def write(cls, qc, **kwargs):
115126
"""
@@ -128,6 +139,17 @@ def write(cls, qc, **kwargs):
128139
# since the mapping operation is non-blocking, each partition will return an empty DF
129140
# so at the end, the blocking operation will be this empty DF to_pandas
130141

142+
if not isinstance(
143+
kwargs["con"], str
144+
) and not cls._is_supported_sqlalchemy_object(kwargs["con"]):
145+
return cls.base_io.to_sql(qc, **kwargs)
146+
147+
# In the case that we are given a SQLAlchemy Connection or Engine, the objects
148+
# are not pickleable. We have to convert it to the URL string and connect from
149+
# each of the workers.
150+
if cls._is_supported_sqlalchemy_object(kwargs["con"]):
151+
kwargs["con"] = str(kwargs["con"].engine.url)
152+
131153
empty_df = qc.getitem_row_array([0]).to_pandas().head(0)
132154
empty_df.to_sql(**kwargs)
133155
# so each partition will append its respective DF

modin/pandas/test/test_io.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2344,19 +2344,28 @@ def test_read_sql_with_chunksize(self, make_sql_connection):
23442344
df_equals(modin_df, pandas_df)
23452345

23462346
@pytest.mark.parametrize("index", [False, True])
2347-
def test_to_sql(self, tmp_path, make_sql_connection, index):
2347+
@pytest.mark.parametrize("conn_type", ["str", "sqlalchemy", "sqlalchemy+connect"])
2348+
def test_to_sql(self, tmp_path, make_sql_connection, index, conn_type):
23482349
table_name = f"test_to_sql_{str(index)}"
23492350
modin_df, pandas_df = create_test_dfs(TEST_DATA)
23502351

23512352
# We do not pass the table name so the fixture won't generate a table
23522353
conn = make_sql_connection(tmp_path / f"{table_name}_modin.db")
2354+
if conn_type.startswith("sqlalchemy"):
2355+
conn = sa.create_engine(conn)
2356+
if conn_type == "sqlalchemy+connect":
2357+
conn = conn.connect()
23532358
modin_df.to_sql(table_name, conn, index=index)
23542359
df_modin_sql = pandas.read_sql(
23552360
table_name, con=conn, index_col="index" if index else None
23562361
)
23572362

23582363
# We do not pass the table name so the fixture won't generate a table
23592364
conn = make_sql_connection(tmp_path / f"{table_name}_pandas.db")
2365+
if conn_type.startswith("sqlalchemy"):
2366+
conn = sa.create_engine(conn)
2367+
if conn_type == "sqlalchemy+connect":
2368+
conn = conn.connect()
23602369
pandas_df.to_sql(table_name, conn, index=index)
23612370
df_pandas_sql = pandas.read_sql(
23622371
table_name, con=conn, index_col="index" if index else None

0 commit comments

Comments
 (0)