Skip to content

Commit adc99f9

Browse files
authored
feat: Support mariadb datasource (#21)
1 parent 484cf88 commit adc99f9

File tree

3 files changed

+47
-3
lines changed

3 files changed

+47
-3
lines changed

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@
190190
<artifactId>hibernate-core</artifactId>
191191
<version>5.3.7.Final</version>
192192
</dependency>
193+
<dependency>
194+
<groupId>org.mariadb.jdbc</groupId>
195+
<artifactId>mariadb-java-client</artifactId>
196+
<version>3.1.4</version>
197+
<scope>test</scope>
198+
</dependency>
193199
<dependency>
194200
<groupId>mysql</groupId>
195201
<artifactId>mysql-connector-java</artifactId>

src/main/java/org/casbin/adapter/HibernateAdapter.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ private void open() throws SQLException {
5959
private void createDatabase() {
6060
Session session = factory.openSession();
6161
Transaction tx = session.beginTransaction();
62-
if (this.databaseProductName.contains("MySQL")) {
62+
if (this.databaseProductName.contains("MySQL") || this.databaseProductName.contains("MariaDB")) {
6363
session.createSQLQuery("CREATE DATABASE IF NOT EXISTS casbin").executeUpdate();
6464
session.createSQLQuery("USE casbin").executeUpdate();
6565
} else if (this.databaseProductName.contains("SQLServer")) {
@@ -78,7 +78,7 @@ private void createDatabase() {
7878
private void createTable() {
7979
Session session = factory.openSession();
8080
Transaction tx = session.beginTransaction();
81-
if (this.databaseProductName.contains("MySQL")) {
81+
if (this.databaseProductName.contains("MySQL") || this.databaseProductName.contains("MariaDB")) {
8282
session.createSQLQuery("CREATE TABLE IF NOT EXISTS casbin_rule (" +
8383
"id INT not NULL primary key," +
8484
"ptype VARCHAR(100) not NULL," +
@@ -130,7 +130,7 @@ private void createTable() {
130130
private void dropTable() {
131131
Session session = factory.openSession();
132132
Transaction tx = session.beginTransaction();
133-
if (this.databaseProductName.contains("MySQL")) {
133+
if (this.databaseProductName.contains("MySQL") || this.databaseProductName.contains("MariaDB")) {
134134
session.createSQLQuery("DROP TABLE IF EXISTS casbin_rule").executeUpdate();
135135
} else if (this.databaseProductName.contains("Oracle")) {
136136
session.createSQLQuery("declare " +
@@ -356,6 +356,8 @@ private void setDatabaseProductName() {
356356
this.databaseProductName = "Oracle";
357357
} else if (this.driver.contains("sqlserver")) {
358358
this.databaseProductName = "SQLServer";
359+
} else if (this.driver.contains("mariadb")) {
360+
this.databaseProductName = "MariaDB";
359361
}
360362
}
361363
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package org.casbin.test;
2+
3+
import static org.junit.Assert.assertEquals;
4+
import java.sql.SQLException;
5+
import org.casbin.adapter.HibernateAdapter;
6+
import org.casbin.jcasbin.main.Enforcer;
7+
import org.junit.Test;
8+
import org.mariadb.jdbc.MariaDbDataSource;
9+
10+
public class HibernateAdapterDatasourceTest {
11+
12+
private static final String URL = "jdbc:mariadb://localhost:3306/casbin?serverTimezone=GMT%2B8&useSSL=false&allowPublicKeyRetrieval=true&rewriteBatchedStatements=true";
13+
private static final String USERNAME = "root";
14+
private static final String PASSWORD = "casbin_test";
15+
16+
@Test
17+
public void testInitDBfromMariaDBDatasource() throws SQLException {
18+
MariaDbDataSource mariaDbDataSource = new MariaDbDataSource();
19+
mariaDbDataSource.setUrl(URL);
20+
mariaDbDataSource.setUser(USERNAME);
21+
mariaDbDataSource.setPassword(PASSWORD);
22+
23+
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf", new HibernateAdapter(mariaDbDataSource));
24+
25+
e.savePolicy(); //clear table
26+
27+
e.addPolicy("admin", "domain1", "data1", "read");
28+
e.addGroupingPolicy("alice", "admin", "domain1");
29+
30+
testDomainEnforce(e, "alice", "domain1", "data1", "read", true);
31+
}
32+
33+
private void testDomainEnforce(Enforcer e, String sub, String dom, String obj, String act, boolean res) {
34+
assertEquals(res, e.enforce(sub, dom, obj, act));
35+
}
36+
}

0 commit comments

Comments
 (0)