diff --git a/R/RcppExports.R b/R/RcppExports.R index e9c98285..c2dde8f7 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -41,6 +41,10 @@ connection_copy_data <- function(con, sql, df) { invisible(.Call(`_RPostgres_connection_copy_data`, con, sql, df)) } +connection_copy_file <- function(con, sql, file) { + invisible(.Call(`_RPostgres_connection_copy_file`, con, sql, file)) +} + connection_wait_for_notify <- function(con, timeout_secs) { .Call(`_RPostgres_connection_wait_for_notify`, con, timeout_secs) } diff --git a/R/tables.R b/R/tables.R index 28e84d7a..28bfd252 100644 --- a/R/tables.R +++ b/R/tables.R @@ -137,6 +137,85 @@ setMethod("dbWriteTable", c("PqConnection", "character", "data.frame"), ) +#' @param header is a logical indicating whether the first data line (but see +#' `skip`) has a header or not. If missing, it value is determined +#' following [read.table()] convention, namely, it is set to TRUE if +#' and only if the first row has one fewer field that the number of columns. +#' @param sep The field separator, defaults to `','`. +#' @param eol The end-of-line delimiter, defaults to `'\n'`. +#' @param skip number of lines to skip before reading the data. Defaults to 0. +#' @param nrows Number of rows to read to determine types. +#' @param colClasses Character vector of R type names, used to override +#' defaults when imputing classes from on-disk file. +#' @param na.strings a character vector of strings which are to be interpreted as NA values. +#' @export +#' @rdname postgres-tables +setMethod("dbWriteTable", c("PqConnection", "character", "character"), + function(conn, name, value, ..., field.types = NULL, overwrite = FALSE, + append = FALSE, header = TRUE, colClasses = NA, row.names = FALSE, + nrows = 50, sep = ",", na.strings = "NA", eol = "\n", skip = 0, temporary = FALSE) { + + if (!is.logical(overwrite) || length(overwrite) != 1L || is.na(overwrite)) { + stopc("`overwrite` must be a logical scalar") + } + if (!is.logical(append) || length(append) != 1L || is.na(append)) { + stopc("`append` must be a logical scalar") + } + if (!is.logical(temporary) || length(temporary) != 1L) { + stopc("`temporary` must be a logical scalar") + } + if (overwrite && append) { + stopc("overwrite and append cannot both be TRUE") + } + if (!is.null(field.types) && !(is.character(field.types) && !is.null(names(field.types)) && !anyDuplicated(names(field.types)))) { + stopc("`field.types` must be a named character vector with unique names, or NULL") + } + if (append && !is.null(field.types)) { + stopc("Cannot specify `field.types` with `append = TRUE`") + } + + found <- dbExistsTable(conn, name) + if (found && !overwrite && !append) { + stop("Table ", name, " exists in database, and both overwrite and", + " append are FALSE", call. = FALSE) + } + if (found && overwrite) { + dbRemoveTable(conn, name) + } + + if (!found || overwrite) { + if (is.null(field.types)) { + tmp_value <- utils::read.table( + value, sep = sep, header = header, skip = skip, nrows = nrows, + na.strings = na.strings, comment.char = "", colClasses = colClasses, + stringsAsFactors = FALSE) + field.types <- lapply(tmp_value, dbDataType, dbObj = conn) + } + + dbCreateTable( + conn = conn, + name = name, + fields = field.types, + temporary = temporary + ) + } + + value <- path.expand(value) + fields <- dbQuoteIdentifier(conn, names(field.types)) + + skip <- skip + as.integer(header) + sql <- paste0( + "COPY ", dbQuoteIdentifier(conn, name), + " (", paste(fields, collapse = ","), ") ", + "FROM STDIN ", "(FORMAT CSV, DELIMITER '", sep, "', HEADER '", header, "')" + ) + + connection_copy_file(conn@ptr, sql, value) + + invisible(TRUE) + } +) + #' @export #' @inheritParams DBI::sqlRownamesToColumn #' @param ... Ignored. diff --git a/man/postgres-tables.Rd b/man/postgres-tables.Rd index 0ffef90a..361b4b9c 100644 --- a/man/postgres-tables.Rd +++ b/man/postgres-tables.Rd @@ -3,6 +3,7 @@ \name{postgres-tables} \alias{postgres-tables} \alias{dbWriteTable,PqConnection,character,data.frame-method} +\alias{dbWriteTable,PqConnection,character,character-method} \alias{sqlData,PqConnection-method} \alias{dbAppendTable,PqConnection-method} \alias{dbReadTable,PqConnection,character-method} @@ -28,6 +29,25 @@ copy = NULL ) +\S4method{dbWriteTable}{PqConnection,character,character}( + conn, + name, + value, + ..., + field.types = NULL, + overwrite = FALSE, + append = FALSE, + header = TRUE, + colClasses = NA, + row.names = FALSE, + nrows = 50, + sep = ",", + na.strings = "NA", + eol = "\\n", + skip = 0, + temporary = FALSE +) + \S4method{sqlData}{PqConnection}(con, value, row.names = FALSE, ...) \S4method{dbAppendTable}{PqConnection}(conn, name, value, copy = NULL, ..., row.names = NULL) @@ -93,6 +113,24 @@ a single SQL string. This is slower, but always supported. The default maps to \code{TRUE} on connections established via \code{\link[=Postgres]{Postgres()}} and to \code{FALSE} on connections established via \code{\link[=Redshift]{Redshift()}}.} +\item{header}{is a logical indicating whether the first data line (but see +\code{skip}) has a header or not. If missing, it value is determined +following \code{\link[=read.table]{read.table()}} convention, namely, it is set to TRUE if +and only if the first row has one fewer field that the number of columns.} + +\item{colClasses}{Character vector of R type names, used to override +defaults when imputing classes from on-disk file.} + +\item{nrows}{Number of rows to read to determine types.} + +\item{sep}{The field separator, defaults to \code{','}.} + +\item{na.strings}{a character vector of strings which are to be interpreted as NA values.} + +\item{eol}{The end-of-line delimiter, defaults to \code{'\\n'}.} + +\item{skip}{number of lines to skip before reading the data. Defaults to 0.} + \item{con}{A database connection.} \item{check.names}{If \code{TRUE}, the default, column names will be diff --git a/src/DbConnection.cpp b/src/DbConnection.cpp index 0da99e0a..139b67bf 100644 --- a/src/DbConnection.cpp +++ b/src/DbConnection.cpp @@ -1,4 +1,5 @@ #include "pch.h" +#include #include "DbConnection.h" #include "encode.h" #include "DbResult.h" @@ -164,6 +165,49 @@ void DbConnection::copy_data(std::string sql, List df) { PQclear(pComplete); } +void DbConnection::copy_csv(std::string sql, std::string file) { + LOG_DEBUG << sql; + + if (file.size() == 0) + return; + + PGresult* pInit = PQexec(pConn_, sql.c_str()); + if (PQresultStatus(pInit) != PGRES_COPY_IN) { + PQclear(pInit); + conn_stop("Failed to initialise COPY"); + } + PQclear(pInit); + + + const size_t buffer_size = 1024 * 64; + std::string buffer; + buffer.reserve(buffer_size); + + std::ifstream fs(file.c_str(), std::ios::in); + if (!fs.is_open()) { + stop("Can not open file '%s'.", file); + } + + while (!fs.eof()) { + buffer.clear(); + fs.read(&buffer[0], buffer_size); + if (PQputCopyData(pConn_, buffer.data(), static_cast(fs.gcount())) != 1) { + conn_stop("Failed to put data"); + } + } + + if (PQputCopyEnd(pConn_, NULL) != 1) { + conn_stop("Failed to finish COPY"); + } + + PGresult* pComplete = PQgetResult(pConn_); + if (PQresultStatus(pComplete) != PGRES_COMMAND_OK) { + PQclear(pComplete); + conn_stop("COPY returned error"); + } + PQclear(pComplete); +} + void DbConnection::check_connection() { if (!pConn_) { stop("Disconnected"); diff --git a/src/DbConnection.h b/src/DbConnection.h index de80bac8..217241c7 100644 --- a/src/DbConnection.h +++ b/src/DbConnection.h @@ -35,6 +35,8 @@ class DbConnection : boost::noncopyable { void copy_data(std::string sql, List df); + void copy_csv(std::string sql, std::string file); + void check_connection(); List info(); diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index f0ad7b62..78841e81 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -124,6 +124,17 @@ BEGIN_RCPP return R_NilValue; END_RCPP } +// connection_copy_file +void connection_copy_file(DbConnection* con, std::string sql, std::string file); +RcppExport SEXP _RPostgres_connection_copy_file(SEXP conSEXP, SEXP sqlSEXP, SEXP fileSEXP) { +BEGIN_RCPP + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< DbConnection* >::type con(conSEXP); + Rcpp::traits::input_parameter< std::string >::type sql(sqlSEXP); + Rcpp::traits::input_parameter< std::string >::type file(fileSEXP); + connection_copy_file(con, sql, file); + return R_NilValue; +} // connection_wait_for_notify List connection_wait_for_notify(DbConnection* con, int timeout_secs); RcppExport SEXP _RPostgres_connection_wait_for_notify(SEXP conSEXP, SEXP timeout_secsSEXP) { @@ -293,6 +304,7 @@ static const R_CallMethodDef CallEntries[] = { {"_RPostgres_connection_is_transacting", (DL_FUNC) &_RPostgres_connection_is_transacting, 1}, {"_RPostgres_connection_set_transacting", (DL_FUNC) &_RPostgres_connection_set_transacting, 2}, {"_RPostgres_connection_copy_data", (DL_FUNC) &_RPostgres_connection_copy_data, 3}, + {"_RPostgres_connection_copy_file", (DL_FUNC) &_RPostgres_connection_copy_file, 3}, {"_RPostgres_connection_wait_for_notify", (DL_FUNC) &_RPostgres_connection_wait_for_notify, 2}, {"_RPostgres_encode_vector", (DL_FUNC) &_RPostgres_encode_vector, 1}, {"_RPostgres_encode_data_frame", (DL_FUNC) &_RPostgres_encode_data_frame, 1}, diff --git a/src/connection.cpp b/src/connection.cpp index dc547fd2..f239dec3 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -100,6 +100,10 @@ void connection_copy_data(DbConnection* con, std::string sql, List df) { } // [[Rcpp::export]] +void connection_copy_file(DbConnection* con, std::string sql, std::string file) { + return con->copy_csv(sql, file); +} + List connection_wait_for_notify(DbConnection* con, int timeout_secs) { return con->wait_for_notify(timeout_secs); } diff --git a/tests/testthat/test-dbWriteTable.R b/tests/testthat/test-dbWriteTable.R index 583f4c16..bf6b9201 100644 --- a/tests/testthat/test-dbWriteTable.R +++ b/tests/testthat/test-dbWriteTable.R @@ -119,6 +119,16 @@ with_database_connection({ expect_equal(dbGetQuery(con, "SELECT * FROM xy"), value) }) }) + + test_that("Writing CSV to the database", { + with_table(con, "iris", { + tmp <- tempfile() + iris2 <- transform(iris, Species = as.character(Species)) + write.csv(iris2, tmp, row.names = FALSE) + dbWriteTable(con, "iris", tmp, temporary = TRUE) + expect_equal(dbReadTable(con, "iris"), iris2) + }) + }) }) describe("Inf values", {