diff --git a/R/Connection.R b/R/Connection.R index 05c685cd..071494be 100644 --- a/R/Connection.R +++ b/R/Connection.R @@ -508,12 +508,19 @@ setMethod( #' @param params Optional query parameters, passed on to [dbBind()] #' @param immediate If `TRUE`, SQLExecDirect will be used instead of #' SQLPrepare, and the `params` argument is ignored +#' @param timeout Number of seconds to wait before aborting the query. +#' The default, `Inf`, will never timeout. #' @export setMethod( "dbSendQuery", c("OdbcConnection", "character"), - function(conn, statement, params = NULL, ..., immediate = FALSE) { - res <- OdbcResult(connection = conn, statement = statement, params = params, immediate = immediate) - res + function(conn, statement, params = NULL, ..., immediate = FALSE, timeout = Inf) { + OdbcResult( + connection = conn, + statement = statement, + params = params, + immediate = immediate, + timeout = timeout + ) }) #' @rdname OdbcConnection @@ -522,9 +529,14 @@ setMethod( #' @export setMethod( "dbSendStatement", c("OdbcConnection", "character"), - function(conn, statement, params = NULL, ..., immediate = FALSE) { - res <- OdbcResult(connection = conn, statement = statement, params = params, immediate = immediate) - res + function(conn, statement, params = NULL, ..., immediate = FALSE, timeout = Inf) { + OdbcResult( + connection = conn, + statement = statement, + params = params, + immediate = immediate, + timeout = timeout + ) }) #' @rdname OdbcConnection @@ -659,8 +671,8 @@ setMethod( #' @inheritParams DBI::dbFetch #' @export setMethod("dbGetQuery", signature("OdbcConnection", "character"), - function(conn, statement, n = -1, params = NULL, ...) { - rs <- dbSendQuery(conn, statement, params = params, ...) + function(conn, statement, ..., n = -1, params = NULL, timeout = Inf) { + rs <- dbSendQuery(conn, statement, params = params, timeout = timeout, ...) on.exit(dbClearResult(rs)) df <- dbFetch(rs, n = n, ...) diff --git a/R/RcppExports.R b/R/RcppExports.R index 7047d58c..05ac7889 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -85,8 +85,8 @@ result_completed <- function(r) { .Call(`_odbc_result_completed`, r) } -new_result <- function(p, sql, immediate) { - .Call(`_odbc_new_result`, p, sql, immediate) +new_result <- function(p, sql, immediate, timeout) { + .Call(`_odbc_new_result`, p, sql, immediate, timeout) } result_fetch <- function(r, n_max = -1L) { diff --git a/R/Result.R b/R/Result.R index 2b5898df..4edc1d20 100644 --- a/R/Result.R +++ b/R/Result.R @@ -9,11 +9,20 @@ NULL #' @docType methods NULL -OdbcResult <- function(connection, statement, params = NULL, immediate = FALSE) { +OdbcResult <- function(connection, statement, params = NULL, immediate = FALSE, timeout = Inf) { if (nzchar(connection@encoding)) { statement <- enc2iconv(statement, connection@encoding) } - ptr <- new_result(connection@ptr, statement, immediate) + if (is.infinite(timeout)) { + timeout <- 0L + } + + ptr <- new_result( + p = connection@ptr, + sql = statement, + immediate = immediate, + timeout = timeout + ) res <- new("OdbcResult", connection = connection, statement = statement, ptr = ptr) if (!is.null(params)) { diff --git a/man/OdbcConnection.Rd b/man/OdbcConnection.Rd index 7751c69d..8d6f229e 100644 --- a/man/OdbcConnection.Rd +++ b/man/OdbcConnection.Rd @@ -30,9 +30,23 @@ \S4method{dbDisconnect}{OdbcConnection}(conn, ...) -\S4method{dbSendQuery}{OdbcConnection,character}(conn, statement, params = NULL, ..., immediate = FALSE) - -\S4method{dbSendStatement}{OdbcConnection,character}(conn, statement, params = NULL, ..., immediate = FALSE) +\S4method{dbSendQuery}{OdbcConnection,character}( + conn, + statement, + params = NULL, + ..., + immediate = FALSE, + timeout = Inf +) + +\S4method{dbSendStatement}{OdbcConnection,character}( + conn, + statement, + params = NULL, + ..., + immediate = FALSE, + timeout = Inf +) \S4method{dbDataType}{OdbcConnection,ANY}(dbObj, obj, ...) @@ -46,7 +60,7 @@ \S4method{dbGetInfo}{OdbcConnection}(dbObj, ...) -\S4method{dbGetQuery}{OdbcConnection,character}(conn, statement, n = -1, params = NULL, ...) +\S4method{dbGetQuery}{OdbcConnection,character}(conn, statement, ..., n = -1, params = NULL, timeout = Inf) \S4method{dbBegin}{OdbcConnection}(conn, ...) @@ -79,6 +93,9 @@ or a \linkS4class{DBIResult}} \item{immediate}{If \code{TRUE}, SQLExecDirect will be used instead of SQLPrepare, and the \code{params} argument is ignored} +\item{timeout}{Number of seconds to wait before aborting the query. +The default, \code{Inf}, will never timeout.} + \item{obj}{An R object whose SQL type we want to determine.} \item{x}{A character vector, \link[DBI]{SQL} or \link[DBI]{Id} object to quote as identifier.} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 907f716b..85357456 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -248,15 +248,16 @@ BEGIN_RCPP END_RCPP } // new_result -result_ptr new_result(connection_ptr const& p, std::string const& sql, const bool immediate); -RcppExport SEXP _odbc_new_result(SEXP pSEXP, SEXP sqlSEXP, SEXP immediateSEXP) { +result_ptr new_result(connection_ptr const& p, std::string const& sql, const bool immediate, long timeout); +RcppExport SEXP _odbc_new_result(SEXP pSEXP, SEXP sqlSEXP, SEXP immediateSEXP, SEXP timeoutSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; Rcpp::traits::input_parameter< connection_ptr const& >::type p(pSEXP); Rcpp::traits::input_parameter< std::string const& >::type sql(sqlSEXP); Rcpp::traits::input_parameter< const bool >::type immediate(immediateSEXP); - rcpp_result_gen = Rcpp::wrap(new_result(p, sql, immediate)); + Rcpp::traits::input_parameter< long >::type timeout(timeoutSEXP); + rcpp_result_gen = Rcpp::wrap(new_result(p, sql, immediate, timeout)); return rcpp_result_gen; END_RCPP } @@ -383,7 +384,7 @@ static const R_CallMethodDef CallEntries[] = { {"_odbc_result_release", (DL_FUNC) &_odbc_result_release, 1}, {"_odbc_result_active", (DL_FUNC) &_odbc_result_active, 1}, {"_odbc_result_completed", (DL_FUNC) &_odbc_result_completed, 1}, - {"_odbc_new_result", (DL_FUNC) &_odbc_new_result, 3}, + {"_odbc_new_result", (DL_FUNC) &_odbc_new_result, 4}, {"_odbc_result_fetch", (DL_FUNC) &_odbc_result_fetch, 2}, {"_odbc_result_column_info", (DL_FUNC) &_odbc_result_column_info, 1}, {"_odbc_result_bind", (DL_FUNC) &_odbc_result_bind, 3}, diff --git a/src/odbc_result.cpp b/src/odbc_result.cpp index d3a882b1..a4b8cbd1 100644 --- a/src/odbc_result.cpp +++ b/src/odbc_result.cpp @@ -7,20 +7,21 @@ namespace odbc { odbc_result::odbc_result( - std::shared_ptr c, std::string sql, bool immediate) + std::shared_ptr c, std::string sql, bool immediate, long timeout) : c_(c), sql_(sql), rows_fetched_(0), num_columns_(0), complete_(0), bound_(false), - output_encoder_(Iconv(c_->encoding(), "UTF-8")) { + output_encoder_(Iconv(c_->encoding(), "UTF-8")), + timeout_(timeout) { if (immediate) { s_ = std::make_shared(); bound_ = true; r_ = std::make_shared( - s_->execute_direct(*c_->connection(), sql_)); + s_->execute_direct(*c_->connection(), sql_, timeout_)); num_columns_ = r_->columns(); c_->set_current_result(this); } else { @@ -42,12 +43,12 @@ std::shared_ptr odbc_result::result() const { return std::shared_ptr(r_); } void odbc_result::prepare() { - s_ = std::make_shared(*c_->connection(), sql_); + s_ = std::make_shared(*c_->connection(), sql_, timeout_); } void odbc_result::execute() { if (!r_) { try { - r_ = std::make_shared(s_->execute()); + r_ = std::make_shared(s_->execute(1L, timeout_)); num_columns_ = r_->columns(); } catch (const nanodbc::database_error& e) { c_->set_current_result(nullptr); @@ -151,7 +152,7 @@ void odbc_result::bind_list( for (short col = 0; col < ncols; ++col) { bind_columns(*s_, types[col], x, col, start, size); } - r_ = std::make_shared(nanodbc::execute(*s_, size)); + r_ = std::make_shared(s_->execute(size, timeout_)); num_columns_ = r_->columns(); start += batch_rows; diff --git a/src/odbc_result.h b/src/odbc_result.h index 44f63553..a79ec43e 100644 --- a/src/odbc_result.h +++ b/src/odbc_result.h @@ -32,7 +32,7 @@ class odbc_error : public Rcpp::exception { class odbc_result { public: odbc_result( - std::shared_ptr c, std::string sql, bool immediate); + std::shared_ptr c, std::string sql, bool immediate, long timeout); std::shared_ptr connection() const; std::shared_ptr statement() const; std::shared_ptr result() const; @@ -63,6 +63,7 @@ class odbc_result { bool complete_; bool bound_; Iconv output_encoder_; + long timeout_; std::map> strings_; std::map>> raws_; diff --git a/src/result.cpp b/src/result.cpp index 9fbca8a7..d95788d8 100644 --- a/src/result.cpp +++ b/src/result.cpp @@ -18,8 +18,8 @@ bool result_completed(result_ptr const& r) { return r->complete(); } // [[Rcpp::export]] result_ptr new_result( - connection_ptr const& p, std::string const& sql, const bool immediate) { - return result_ptr(new odbc::odbc_result(*p, sql, immediate)); + connection_ptr const& p, std::string const& sql, const bool immediate, long timeout) { + return result_ptr(new odbc::odbc_result(*p, sql, immediate, timeout)); } // [[Rcpp::export]] diff --git a/tests/testthat/_snaps/SQLServer.md b/tests/testthat/_snaps/SQLServer.md index 5a232d17..cff9393a 100644 --- a/tests/testthat/_snaps/SQLServer.md +++ b/tests/testthat/_snaps/SQLServer.md @@ -6,3 +6,13 @@ Temporary flag is set to true, but table name doesn't use # prefix +# timeout is respected + + Code + dbGetQuery(con, "WaitFor Delay '00:00:03'; SELECT 1 as x", timeout = 1) + Condition + Error: + ! nanodbc/nanodbc.cpp:1769: 00000 + [Microsoft][ODBC Driver 18 for SQL Server]Query timeout expired + 'WaitFor Delay '00:00:03'; SELECT 1 as x' + diff --git a/tests/testthat/test-SQLServer.R b/tests/testthat/test-SQLServer.R index 40420ec0..01debb36 100644 --- a/tests/testthat/test-SQLServer.R +++ b/tests/testthat/test-SQLServer.R @@ -310,4 +310,15 @@ test_that("SQLServer", { expect_true(grepl("\n", e$message)) }) }) + + test_that("timeout is respected", { + con <- DBItest:::connect(DBItest:::get_default_context()) + res <- dbGetQuery(con, "WaitFor Delay '00:00:01'; SELECT 1 as x", timeout = 2) + expect_equal(res, data.frame(x = 1)) + + expect_snapshot( + dbGetQuery(con, "WaitFor Delay '00:00:03'; SELECT 1 as x", timeout = 1), + error = TRUE + ) + }) })