Commit c1ecd65

[postgres] Improve pattern matching for prepared statements
1 parent 5022e54

File tree: 2 files changed, +55 −11 lines

spark-postgres/src/main/scala/io/frama/parisni/spark/postgres/PGTool.scala

8 additions, 11 deletions

@@ -392,17 +392,14 @@ object PGTool extends java.io.Serializable with LazyLogging {

   def parametrize(st: PreparedStatement, params: List[Any]) = {
     for ((obj, i) <- params.zipWithIndex) {
-      obj.getClass.getCanonicalName match {
-        case "java.lang.String" => st.setString(i + 1, obj.asInstanceOf[String])
-        case "java.lang.Boolean" =>
-          st.setBoolean(i + 1, obj.asInstanceOf[Boolean])
-        case "java.lang.Long" => st.setLong(i + 1, obj.asInstanceOf[Long])
-        case "java.lang.Integer" => st.setInt(i + 1, obj.asInstanceOf[Int])
-        case "java.math.BigDecimal" =>
-          st.setDouble(i + 1, obj.asInstanceOf[Double])
-        case "java.sql.Date" => st.setDate(i + 1, obj.asInstanceOf[Date])
-        case "java.sql.Timestamp" =>
-          st.setTimestamp(i + 1, obj.asInstanceOf[Timestamp])
+      obj match {
+        case s: String => st.setString(i + 1, s)
+        case b: Boolean => st.setBoolean(i + 1, b)
+        case l: Long => st.setLong(i + 1, l)
+        case n: Integer => st.setInt(i + 1, n) // binder n must not shadow the loop index i
+        case b: java.math.BigDecimal => st.setDouble(i + 1, b.doubleValue())
+        case d: java.sql.Date => st.setDate(i + 1, d)
+        case t: Timestamp => st.setTimestamp(i + 1, t)
         case _ =>
           throw new UnsupportedEncodingException(
             obj.getClass.getCanonicalName + " type not yet supported for prepared statements")
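
The rewritten match dispatches on the runtime type directly and binds each matched value, so the class-name string comparisons and asInstanceOf casts disappear. For context, a minimal sketch of a call site, assuming parametrize is accessible on the PGTool object and a live JDBC connection; the URL, credentials, and table here are illustrative, not part of this commit:

import java.sql.DriverManager

// Illustrative setup -- connection details are assumptions, not from this commit.
val conn = DriverManager.getConnection(
  "jdbc:postgresql://localhost:5432/postgres", "postgres", "")
val st = conn.prepareStatement(
  "insert into test_crud (colstring, colint, colboolean) values (?, ?, ?)")

// Each List[Any] element arrives boxed (java.lang.Integer, java.lang.Boolean, ...);
// the typed patterns unbox it and dispatch to the matching setXxx setter.
PGTool.parametrize(st, List("bob", 1, false))
st.executeUpdate()
conn.close()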
spark-postgres/src/test/scala/io/frama/parisni/spark/postgres/CrudTest.scala (new file)

47 additions, 0 deletions

@@ -0,0 +1,47 @@
package io.frama.parisni.spark.postgres

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.junit.Test

class CrudTest extends QueryTest with SparkSessionTestWrapper {

  @Test
  def verifyParametrised(): Unit = {
    import spark.implicits._

    val df: DataFrame = (("bob", 2, true) ::
      Nil).toDF("colstring", "colint", "colboolean")
    getPgTool().tableCreate("test_crud", df.schema)
    val query =
      "insert into test_crud (colstring, colint, colboolean) values (?, ?, ?)"
    getPgTool().sqlExec(query, List("bob", 1, false))
    val output = spark.read
      .format("io.frama.parisni.spark.postgres")
      .option("host", "localhost")
      .option("port", pg.getEmbeddedPostgres.getPort)
      .option("database", "postgres")
      .option("user", "postgres")
      .option("query", "select * from test_crud")
      .load
    val wanted =
      (("bob", 1, false) :: Nil).toDF("colstring", "colint", "colboolean")
    checkAnswer(output, wanted)
  }

  @Test
  def verifyParametrisedWithResult(): Unit = {
    import spark.implicits._

    val df: DataFrame = (("bob", 2, true) ::
      Nil).toDF("colstring", "colint", "colboolean")
    getPgTool().tableCreate("test_crud2", df.schema)
    val query =
      "insert into test_crud2 (colstring, colint, colboolean) values (?, ?, ?) returning colstring, colint, colboolean"
    val result = getPgTool().sqlExecWithResult(query, List("bob", 1, false))
    val wanted =
      (("bob", 1, false) :: Nil).toDF("colstring", "colint", "colboolean")
    checkAnswer(result, wanted)
  }
}
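
Why type patterns are the better fit here: values stored in a List[Any] are boxed, and Scala type patterns see through the boxing, whereas matching on getClass.getCanonicalName strings is brittle (a Scala Int reports itself as java.lang.Integer, and a typo in a string case silently falls through to the default branch). A standalone sketch, not part of the commit:

// Standalone illustration -- not from the commit.
val params: List[Any] = List("bob", 1, false)

params.foreach {
  // A Scala Int stored in Any is boxed, so getClass reports java.lang.Integer,
  // yet a primitive type pattern still matches: Scala unboxes inside patterns.
  case s: String  => println(s"String  -> $s")
  case n: Int     => println(s"Int     -> $n")
  case b: Boolean => println(s"Boolean -> $b")
  case other =>
    println(s"unhandled: ${other.getClass.getCanonicalName}")
}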
