diff --git a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt index 583ced2f93..2b6d0e1b63 100644 --- a/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt +++ b/dataframe-jdbc/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readJdbc.kt @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.api.Infer import org.jetbrains.kotlinx.dataframe.api.toDataFrame import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl import org.jetbrains.kotlinx.dataframe.io.db.DbType @@ -105,15 +106,17 @@ public data class DatabaseConfiguration(val url: String, val user: String = "", * @param [dbConfig] the configuration for the database, including URL, user, and password. * @param [tableName] the name of the table to read data from. * @param [limit] the maximum number of rows to retrieve from the table. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return the DataFrame containing the data from the SQL table. */ public fun DataFrame.Companion.readSqlTable( dbConfig: DatabaseConfiguration, tableName: String, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlTable(connection, tableName, limit) + return readSqlTable(connection, tableName, limit, inferNullability) } } @@ -123,6 +126,7 @@ public fun DataFrame.Companion.readSqlTable( * @param [connection] the database connection to read tables from. * @param [tableName] the name of the table to read data from. * @param [limit] the maximum number of rows to retrieve from the table. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return the DataFrame containing the data from the SQL table. * * @see DriverManager.getConnection @@ -130,7 +134,8 @@ public fun DataFrame.Companion.readSqlTable( public fun DataFrame.Companion.readSqlTable( connection: Connection, tableName: String, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): AnyFrame { var preparedQuery = "SELECT * FROM $tableName" if (limit > 0) preparedQuery += " LIMIT $limit" @@ -145,7 +150,7 @@ public fun DataFrame.Companion.readSqlTable( preparedQuery ).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit) + return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) } } } @@ -159,15 +164,17 @@ public fun DataFrame.Companion.readSqlTable( * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return the DataFrame containing the result of the SQL query. */ public fun DataFrame.Companion.readSqlQuery( dbConfig: DatabaseConfiguration, sqlQuery: String, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): AnyFrame { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readSqlQuery(connection, sqlQuery, limit) + return readSqlQuery(connection, sqlQuery, limit, inferNullability) } } @@ -180,6 +187,7 @@ public fun DataFrame.Companion.readSqlQuery( * @param [connection] the database connection to execute the SQL query. * @param [sqlQuery] the SQL query to execute. * @param [limit] the maximum number of rows to retrieve from the result of the SQL query execution. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return the DataFrame containing the result of the SQL query. * * @see DriverManager.getConnection @@ -187,9 +195,13 @@ public fun DataFrame.Companion.readSqlQuery( public fun DataFrame.Companion.readSqlQuery( connection: Connection, sqlQuery: String, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): AnyFrame { - require(isValid(sqlQuery)) { "SQL query should start from SELECT and contain one query for reading data without any manipulation. " } + require(isValid(sqlQuery)) { + "SQL query should start from SELECT and contain one query for reading data without any manipulation. " + + "Also it should not contain any separators like `;`." + } val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) @@ -202,12 +214,12 @@ public fun DataFrame.Companion.readSqlQuery( connection.createStatement().use { st -> st.executeQuery(internalSqlQuery).use { rs -> val tableColumns = getTableColumnsMetadata(rs) - return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, DEFAULT_LIMIT) + return fetchAndConvertDataFromResultSet(tableColumns, rs, dbType, limit, inferNullability) } } } -/** SQL-query is accepted only if it starts from SELECT */ +/** SQL query is accepted only if it starts from SELECT */ private fun isValid(sqlQuery: String): Boolean { val normalizedSqlQuery = sqlQuery.trim().uppercase() @@ -221,15 +233,17 @@ private fun isValid(sqlQuery: String): Boolean { * @param [resultSet] the [ResultSet] containing the data to read. * @param [dbType] the type of database that the [ResultSet] belongs to. * @param [limit] the maximum number of rows to read from the [ResultSet]. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return the DataFrame generated from the [ResultSet] data. */ public fun DataFrame.Companion.readResultSet( resultSet: ResultSet, dbType: DbType, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): AnyFrame { val tableColumns = getTableColumnsMetadata(resultSet) - return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit) + return fetchAndConvertDataFromResultSet(tableColumns, resultSet, dbType, limit, inferNullability) } /** @@ -238,17 +252,19 @@ public fun DataFrame.Companion.readResultSet( * @param [resultSet] the [ResultSet] containing the data to read. * @param [connection] the connection to the database (it's required to extract the database type). * @param [limit] the maximum number of rows to read from the [ResultSet]. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return the DataFrame generated from the [ResultSet] data. */ public fun DataFrame.Companion.readResultSet( resultSet: ResultSet, connection: Connection, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): AnyFrame { val url = connection.metaData.url val dbType = extractDBTypeFromUrl(url) - return readResultSet(resultSet, dbType, limit) + return readResultSet(resultSet, dbType, limit, inferNullability) } /** @@ -256,15 +272,18 @@ public fun DataFrame.Companion.readResultSet( * * @param [dbConfig] the database configuration to connect to the database, including URL, user, and password. * @param [limit] the maximum number of rows to read from each table. + * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return a list of [AnyFrame] objects representing the non-system tables from the database. */ public fun DataFrame.Companion.readAllSqlTables( dbConfig: DatabaseConfiguration, catalogue: String? = null, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): List { DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection -> - return readAllSqlTables(connection, catalogue, limit) + return readAllSqlTables(connection, catalogue, limit, inferNullability) } } @@ -273,6 +292,8 @@ public fun DataFrame.Companion.readAllSqlTables( * * @param [connection] the database connection to read tables from. * @param [limit] the maximum number of rows to read from each table. + * @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs. + * @param [inferNullability] indicates how the column nullability should be inferred. * @return a list of [AnyFrame] objects representing the non-system tables from the database. * * @see DriverManager.getConnection @@ -280,7 +301,8 @@ public fun DataFrame.Companion.readAllSqlTables( public fun DataFrame.Companion.readAllSqlTables( connection: Connection, catalogue: String? = null, - limit: Int = DEFAULT_LIMIT + limit: Int = DEFAULT_LIMIT, + inferNullability: Boolean = true, ): List { val metaData = connection.metaData val url = connection.metaData.url @@ -304,7 +326,7 @@ public fun DataFrame.Companion.readAllSqlTables( // could be Dialect/Database specific logger.debug { "Reading table: $tableName" } - val dataFrame = readSqlTable(connection, tableName, limit) + val dataFrame = readSqlTable(connection, tableName, limit, inferNullability) dataFrames += dataFrame logger.debug { "Finished reading table: $tableName" } } @@ -450,7 +472,7 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): val dbType = extractDBTypeFromUrl(url) val tableTypes = arrayOf("TABLE") - // exclude system and other tables without data + // exclude a system and other tables without data val tables = metaData.getTables(null, null, null, tableTypes) val dataFrameSchemas = mutableListOf() @@ -561,13 +583,15 @@ private fun manageColumnNameDuplication(columnNameCounter: MutableMap, rs: ResultSet, dbType: DbType, - limit: Int + limit: Int, + inferNullability: Boolean, ): AnyFrame { val data = List(tableColumns.size) { mutableListOf() } @@ -596,6 +620,7 @@ private fun fetchAndConvertDataFromResultSet( DataColumn.createValueColumn( name = tableColumns[index].name, values = values, + infer = convertNullabilityInference(inferNullability), type = kotlinTypesForSqlColumns[index]!! ) }.toDataFrame() @@ -605,6 +630,8 @@ private fun fetchAndConvertDataFromResultSet( return dataFrame } +private fun convertNullabilityInference(inferNullability: Boolean) = if (inferNullability) Infer.Nulls else Infer.None + private fun extractNewRowFromResultSetAndAddToData( tableColumns: MutableList, data: List>, diff --git a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt index b9d753d08a..864a0c4dad 100644 --- a/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt +++ b/dataframe-jdbc/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/h2Test.kt @@ -6,10 +6,7 @@ import org.h2.jdbc.JdbcSQLSyntaxErrorException import org.intellij.lang.annotations.Language import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.annotations.DataSchema -import org.jetbrains.kotlinx.dataframe.api.add -import org.jetbrains.kotlinx.dataframe.api.cast -import org.jetbrains.kotlinx.dataframe.api.filter -import org.jetbrains.kotlinx.dataframe.api.select +import org.jetbrains.kotlinx.dataframe.api.* import org.jetbrains.kotlinx.dataframe.io.db.H2 import org.junit.AfterClass import org.junit.BeforeClass @@ -677,4 +674,115 @@ class JdbcTest { saleDataSchema1.columns.size shouldBe 3 saleDataSchema1.columns["amount"]!!.type shouldBe typeOf() } + + @Test + fun `infer nullability`() { + // prepare tables and data + @Language("SQL") + val createTestTable1Query = """ + CREATE TABLE TestTable1 ( + id INT PRIMARY KEY, + name VARCHAR(50), + surname VARCHAR(50), + age INT NOT NULL + ) + """ + + connection.createStatement().execute(createTestTable1Query) + + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (1, 'John', 'Crawford', 40)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (2, 'Alice', 'Smith', 25)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (3, 'Bob', 'Johnson', 47)") + connection.createStatement().execute("INSERT INTO TestTable1 (id, name, surname, age) VALUES (4, 'Sam', NULL, 15)") + + // start testing `readSqlTable` method + + // with default inferNullability: Boolean = true + val tableName = "TestTable1" + val df = DataFrame.readSqlTable(connection, tableName) + df.schema().columns["id"]!!.type shouldBe typeOf() + df.schema().columns["name"]!!.type shouldBe typeOf() + df.schema().columns["surname"]!!.type shouldBe typeOf() + df.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema = DataFrame.getSchemaForSqlTable(connection, tableName) + dataSchema.columns.size shouldBe 4 + dataSchema.columns["id"]!!.type shouldBe typeOf() + dataSchema.columns["name"]!!.type shouldBe typeOf() + dataSchema.columns["surname"]!!.type shouldBe typeOf() + dataSchema.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df1 = DataFrame.readSqlTable(connection, tableName, inferNullability = false) + df1.schema().columns["id"]!!.type shouldBe typeOf() + df1.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df1.schema().columns["surname"]!!.type shouldBe typeOf() + df1.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSqlTable` method + + // start testing `readSQLQuery` method + + // ith default inferNullability: Boolean = true + @Language("SQL") + val sqlQuery = """ + SELECT name, surname, age FROM TestTable1 + """.trimIndent() + + val df2 = DataFrame.readSqlQuery(connection, sqlQuery) + df2.schema().columns["name"]!!.type shouldBe typeOf() + df2.schema().columns["surname"]!!.type shouldBe typeOf() + df2.schema().columns["age"]!!.type shouldBe typeOf() + + val dataSchema2 = DataFrame.getSchemaForSqlQuery(connection, sqlQuery) + dataSchema2.columns.size shouldBe 3 + dataSchema2.columns["name"]!!.type shouldBe typeOf() + dataSchema2.columns["surname"]!!.type shouldBe typeOf() + dataSchema2.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + val df3 = DataFrame.readSqlQuery(connection, sqlQuery, inferNullability = false) + df3.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df3.schema().columns["surname"]!!.type shouldBe typeOf() + df3.schema().columns["age"]!!.type shouldBe typeOf() + + // end testing `readSQLQuery` method + + // start testing `readResultSet` method + + connection.createStatement(ResultSet.TYPE_SCROLL_SENSITIVE, ResultSet.CONCUR_UPDATABLE).use { st -> + @Language("SQL") + val selectStatement = "SELECT * FROM TestTable1" + + st.executeQuery(selectStatement).use { rs -> + // ith default inferNullability: Boolean = true + val df4 = DataFrame.readResultSet(rs, H2) + df4.schema().columns["id"]!!.type shouldBe typeOf() + df4.schema().columns["name"]!!.type shouldBe typeOf() + df4.schema().columns["surname"]!!.type shouldBe typeOf() + df4.schema().columns["age"]!!.type shouldBe typeOf() + + rs.beforeFirst() + + val dataSchema3 = DataFrame.getSchemaForResultSet(rs, H2) + dataSchema3.columns.size shouldBe 4 + dataSchema3.columns["id"]!!.type shouldBe typeOf() + dataSchema3.columns["name"]!!.type shouldBe typeOf() + dataSchema3.columns["surname"]!!.type shouldBe typeOf() + dataSchema3.columns["age"]!!.type shouldBe typeOf() + + // with inferNullability: Boolean = false + rs.beforeFirst() + + val df5 = DataFrame.readResultSet(rs, H2, inferNullability = false) + df5.schema().columns["id"]!!.type shouldBe typeOf() + df5.schema().columns["name"]!!.type shouldBe typeOf() // <=== this column changed a type because it doesn't contain nulls + df5.schema().columns["surname"]!!.type shouldBe typeOf() + df5.schema().columns["age"]!!.type shouldBe typeOf() + } + } + // end testing `readResultSet` method + + connection.createStatement().execute("DROP TABLE TestTable1") + } }