Skip to content

Commit 1688c0d

Browse files
committed
#19 add support for mariadb datasource, added unit test, added constants file for database product names
1 parent 484cf88 commit 1688c0d

File tree

4 files changed

+75
-8
lines changed

4 files changed

+75
-8
lines changed

pom.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,16 @@
180180
<version>4.11</version>
181181
<scope>test</scope>
182182
</dependency>
183+
<dependency>
184+
<groupId>com.zaxxer</groupId>
185+
<artifactId>HikariCP</artifactId>
186+
<version>2.5.1</version>
187+
</dependency>
188+
<dependency>
189+
<groupId>org.mariadb.jdbc</groupId>
190+
<artifactId>mariadb-java-client</artifactId>
191+
<version>3.1.4</version>
192+
</dependency>
183193
<dependency>
184194
<groupId>org.casbin</groupId>
185195
<artifactId>jcasbin</artifactId>

src/main/java/org/casbin/adapter/HibernateAdapter.java

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.casbin.adapter;
22

33
import org.apache.commons.collections.CollectionUtils;
4+
import org.casbin.constants.AdapterConstants;
45
import org.casbin.jcasbin.model.Assertion;
56
import org.casbin.jcasbin.model.Model;
67
import org.casbin.jcasbin.persist.Adapter;
@@ -59,10 +60,10 @@ private void open() throws SQLException {
5960
private void createDatabase() {
6061
Session session = factory.openSession();
6162
Transaction tx = session.beginTransaction();
62-
if (this.databaseProductName.contains("MySQL")) {
63+
if (this.databaseProductName.contains(AdapterConstants.MySQL.name()) || this.databaseProductName.contains(AdapterConstants.MariaDB.name())) {
6364
session.createSQLQuery("CREATE DATABASE IF NOT EXISTS casbin").executeUpdate();
6465
session.createSQLQuery("USE casbin").executeUpdate();
65-
} else if (this.databaseProductName.contains("SQLServer")) {
66+
} else if (this.databaseProductName.contains(AdapterConstants.SQLServer.name())) {
6667
session.createSQLQuery("IF NOT EXISTS (" +
6768
"SELECT * FROM sysdatabases WHERE name = 'casbin') CREATE DATABASE casbin ON PRIMARY " +
6869
"( NAME = N'casbin', FILENAME = N'C:\\Program Files\\Microsoft SQL Server\\MSSQL.1\\MSSQL\\DATA\\casbinDB.mdf' , SIZE = 3072KB , MAXSIZE = UNLIMITED, FILEGROWTH = 1024KB ) " +
@@ -78,7 +79,7 @@ private void createDatabase() {
7879
private void createTable() {
7980
Session session = factory.openSession();
8081
Transaction tx = session.beginTransaction();
81-
if (this.databaseProductName.contains("MySQL")) {
82+
if (this.databaseProductName.contains(AdapterConstants.MySQL.name()) || this.databaseProductName.contains(AdapterConstants.MariaDB.name())) {
8283
session.createSQLQuery("CREATE TABLE IF NOT EXISTS casbin_rule (" +
8384
"id INT not NULL primary key," +
8485
"ptype VARCHAR(100) not NULL," +
@@ -88,7 +89,7 @@ private void createTable() {
8889
"v3 VARCHAR(100)," +
8990
"v4 VARCHAR(100)," +
9091
"v5 VARCHAR(100))").executeUpdate();
91-
} else if (this.databaseProductName.contains("Oracle")) {
92+
} else if (this.databaseProductName.contains(AdapterConstants.Oracle.name())) {
9293
session.createSQLQuery("declare " +
9394
"nCount NUMBER;" +
9495
"v_sql LONG;" +
@@ -109,7 +110,7 @@ private void createTable() {
109110
"execute immediate v_sql;" +
110111
"END IF;" +
111112
"end;").executeUpdate();
112-
} else if (this.databaseProductName.contains("SQLServer")) {
113+
} else if (this.databaseProductName.contains(AdapterConstants.SQLServer.name())) {
113114
session.createSQLQuery("if not exists (select * from sysobjects where id = object_id('casbin_rule')) " +
114115
"create table casbin_rule (" +
115116
" id int, " +
@@ -130,9 +131,9 @@ private void createTable() {
130131
private void dropTable() {
131132
Session session = factory.openSession();
132133
Transaction tx = session.beginTransaction();
133-
if (this.databaseProductName.contains("MySQL")) {
134+
if (this.databaseProductName.contains(AdapterConstants.MySQL.name()) || this.databaseProductName.contains(AdapterConstants.MariaDB.name())) {
134135
session.createSQLQuery("DROP TABLE IF EXISTS casbin_rule").executeUpdate();
135-
} else if (this.databaseProductName.contains("Oracle")) {
136+
} else if (this.databaseProductName.contains(AdapterConstants.Oracle.name())) {
136137
session.createSQLQuery("declare " +
137138
"nCount NUMBER;" +
138139
"v_sql LONG;" +
@@ -144,7 +145,7 @@ private void dropTable() {
144145
"execute immediate v_sql;" +
145146
"END IF;" +
146147
"end;").executeUpdate();
147-
} else if (this.databaseProductName.contains("SQLServer")) {
148+
} else if (this.databaseProductName.contains(AdapterConstants.SQLServer.name())) {
148149
session.createSQLQuery("if exists (select * from sysobjects where id = object_id('casbin_rule') drop table casbin_rule").executeUpdate();
149150
}
150151
tx.commit();
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package org.casbin.constants;
2+
3+
public enum AdapterConstants {
4+
// database product names
5+
MySQL,
6+
MariaDB,
7+
Oracle,
8+
SQLServer;
9+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package org.casbin.test;
2+
3+
import static org.junit.Assert.assertEquals;
4+
5+
import java.sql.SQLException;
6+
7+
import org.casbin.adapter.HibernateAdapter;
8+
import org.casbin.jcasbin.main.Enforcer;
9+
import org.junit.Test;
10+
11+
import com.zaxxer.hikari.HikariConfig;
12+
import com.zaxxer.hikari.HikariDataSource;
13+
14+
public class HibernateAdapterDatasourceTest {
15+
16+
private static final String DRIVER = "org.mariadb.jdbc.Driver";
17+
private static final String URL = "jdbc:mariadb://localhost:3306/casbin?serverTimezone=GMT%2B8&useSSL=false&allowPublicKeyRetrieval=true&rewriteBatchedStatements=true";
18+
private static final String USERNAME = "root";
19+
private static final String PASSWORD = "casbin_test";
20+
private static HikariDataSource dataSource;
21+
22+
23+
@Test
24+
public void testInitDBfromMariaDBDatasource() throws SQLException {
25+
26+
HikariConfig config = new HikariConfig();
27+
config.setJdbcUrl(URL);
28+
config.setUsername(USERNAME);
29+
config.setPassword(PASSWORD);
30+
config.setDriverClassName(DRIVER);
31+
config.setMaximumPoolSize(5);
32+
33+
dataSource = new HikariDataSource(config);
34+
Enforcer e = new Enforcer("examples/rbac_with_domains_model.conf", new HibernateAdapter(dataSource));
35+
36+
e.savePolicy(); //clear table
37+
38+
e.addPolicy("admin", "domain1", "data1", "read");
39+
e.addGroupingPolicy("alice", "admin", "domain1");
40+
41+
testDomainEnforce(e, "alice", "domain1", "data1", "read", true);
42+
}
43+
44+
private void testDomainEnforce(Enforcer e, String sub, String dom, String obj, String act, boolean res) {
45+
assertEquals(res, e.enforce(sub, dom, obj, act));
46+
}
47+
}

0 commit comments

Comments
 (0)