diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index a785839..b0a95d4 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -118,3 +118,24 @@ jobs: ./start-connect-server.sh cd ../.. swift test --no-parallel + + integration-test-mac-spark3: + runs-on: macos-15 + steps: + - uses: actions/checkout@v4 + - uses: swift-actions/setup-swift@v2.3.0 + with: + swift-version: "6.1" + - name: Install Java + uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: 17 + - name: Test + run: | + curl -LO https://downloads.apache.org/spark/spark-3.5.5/spark-3.5.5-bin-hadoop3.tgz + tar xvfz spark-3.5.5-bin-hadoop3.tgz + cd spark-3.5.5-bin-hadoop3/sbin + ./start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.5.5 + cd ../.. + swift test --no-parallel diff --git a/Tests/SparkConnectTests/DataFrameInternalTests.swift b/Tests/SparkConnectTests/DataFrameInternalTests.swift index 49814aa..d557400 100644 --- a/Tests/SparkConnectTests/DataFrameInternalTests.swift +++ b/Tests/SparkConnectTests/DataFrameInternalTests.swift @@ -32,7 +32,7 @@ struct DataFrameInternalTests { #expect(rows.count == 1) #expect(rows[0].length == 1) #expect( - try rows[0].get(0) as! String == """ + try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ +---+ |id | +---+ @@ -73,7 +73,7 @@ struct DataFrameInternalTests { #expect(rows[0].length == 1) print(try rows[0].get(0) as! String) #expect( - try rows[0].get(0) as! String == """ + try (rows[0].get(0) as! String).trimmingCharacters(in: .whitespacesAndNewlines) == """ -RECORD 0-- id | 0 -RECORD 1-- diff --git a/Tests/SparkConnectTests/DataFrameReaderTests.swift b/Tests/SparkConnectTests/DataFrameReaderTests.swift index f1049b1..1f71d0a 100644 --- a/Tests/SparkConnectTests/DataFrameReaderTests.swift +++ b/Tests/SparkConnectTests/DataFrameReaderTests.swift @@ -48,10 +48,12 @@ struct DataFrameReaderTests { @Test func xml() async throws { let spark = try await SparkSession.builder.getOrCreate() - let path = "../examples/src/main/resources/people.xml" - #expect(try await spark.read.option("rowTag", "person").format("xml").load(path).count() == 3) - #expect(try await spark.read.option("rowTag", "person").xml(path).count() == 3) - #expect(try await spark.read.option("rowTag", "person").xml(path, path).count() == 6) + if await spark.version >= "4.0.0" { + let path = "../examples/src/main/resources/people.xml" + #expect(try await spark.read.option("rowTag", "person").format("xml").load(path).count() == 3) + #expect(try await spark.read.option("rowTag", "person").xml(path).count() == 3) + #expect(try await spark.read.option("rowTag", "person").xml(path, path).count() == 6) + } await spark.stop() } @@ -80,7 +82,7 @@ struct DataFrameReaderTests { let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") let spark = try await SparkSession.builder.getOrCreate() try await SQLHelper.withTable(spark, tableName)({ - _ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() + _ = try await spark.sql("CREATE TABLE \(tableName) USING ORC AS VALUES (1), (2)").count() #expect(try await spark.read.table(tableName).count() == 2) }) await spark.stop() diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift index f8673e5..e39e83b 100644 --- a/Tests/SparkConnectTests/DataFrameTests.swift +++ b/Tests/SparkConnectTests/DataFrameTests.swift @@ -69,16 +69,20 @@ struct DataFrameTests { let spark = try await SparkSession.builder.getOrCreate() let schema1 = try await spark.sql("SELECT 'a' as col1").schema - #expect( - schema1 - == #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# - ) + let answer1 = if await spark.version.starts(with: "4.") { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# + } else { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}}]}}"# + } + #expect(schema1 == answer1) let schema2 = try await spark.sql("SELECT 'a' as col1, 'b' as col2").schema - #expect( - schema2 - == #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# - ) + let answer2 = if await spark.version.starts(with: "4.") { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{"collation":"UTF8_BINARY"}}},{"name":"col2","dataType":{"string":{"collation":"UTF8_BINARY"}}}]}}"# + } else { + #"{"struct":{"fields":[{"name":"col1","dataType":{"string":{}}},{"name":"col2","dataType":{"string":{}}}]}}"# + } + #expect(schema2 == answer2) let emptySchema = try await spark.sql("DROP TABLE IF EXISTS nonexistent").schema #expect(emptySchema == #"{"struct":{}}"#) @@ -319,11 +323,12 @@ struct DataFrameTests { let spark = try await SparkSession.builder.getOrCreate() #expect(try await spark.sql("DROP TABLE IF EXISTS t").count() == 0) #expect(try await spark.sql("SHOW TABLES").count() == 0) - #expect(try await spark.sql("CREATE TABLE IF NOT EXISTS t(a INT)").count() == 0) + #expect(try await spark.sql("CREATE TABLE IF NOT EXISTS t(a INT) USING ORC").count() == 0) #expect(try await spark.sql("SHOW TABLES").count() == 1) #expect(try await spark.sql("SELECT * FROM t").count() == 0) #expect(try await spark.sql("INSERT INTO t VALUES (1), (2), (3)").count() == 0) #expect(try await spark.sql("SELECT * FROM t").count() == 3) + #expect(try await spark.sql("DROP TABLE IF EXISTS t").count() == 0) await spark.stop() } @@ -482,20 +487,22 @@ struct DataFrameTests { @Test func lateralJoin() async throws { let spark = try await SparkSession.builder.getOrCreate() - let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)") - let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)") - let expectedCross = [ - Row("a", "1", "c", "2"), - Row("a", "1", "d", "3"), - Row("b", "2", "c", "2"), - Row("b", "2", "d", "3"), - ] - #expect(try await df1.lateralJoin(df2).collect() == expectedCross) - #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross) - - let expected = [Row("b", "2", "c", "2")] - #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected) - #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) + if await spark.version.starts(with: "4.") { + let df1 = try await spark.sql("SELECT * FROM VALUES ('a', '1'), ('b', '2') AS T(a, b)") + let df2 = try await spark.sql("SELECT * FROM VALUES ('c', '2'), ('d', '3') AS S(c, b)") + let expectedCross = [ + Row("a", "1", "c", "2"), + Row("a", "1", "d", "3"), + Row("b", "2", "c", "2"), + Row("b", "2", "d", "3"), + ] + #expect(try await df1.lateralJoin(df2).collect() == expectedCross) + #expect(try await df1.lateralJoin(df2, joinType: "inner").collect() == expectedCross) + + let expected = [Row("b", "2", "c", "2")] + #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b").collect() == expected) + #expect(try await df1.lateralJoin(df2, joinExprs: "T.b = S.b", joinType: "inner").collect() == expected) + } await spark.stop() } diff --git a/Tests/SparkConnectTests/DataFrameWriterTests.swift b/Tests/SparkConnectTests/DataFrameWriterTests.swift index da6d190..464e288 100644 --- a/Tests/SparkConnectTests/DataFrameWriterTests.swift +++ b/Tests/SparkConnectTests/DataFrameWriterTests.swift @@ -47,8 +47,10 @@ struct DataFrameWriterTests { func xml() async throws { let tmpDir = "/tmp/" + UUID().uuidString let spark = try await SparkSession.builder.getOrCreate() - try await spark.range(2025).write.option("rowTag", "person").xml(tmpDir) - #expect(try await spark.read.option("rowTag", "person").xml(tmpDir).count() == 2025) + if await spark.version >= "4.0.0" { + try await spark.range(2025).write.option("rowTag", "person").xml(tmpDir) + #expect(try await spark.read.option("rowTag", "person").xml(tmpDir).count() == 2025) + } await spark.stop() } diff --git a/Tests/SparkConnectTests/SQLTests.swift b/Tests/SparkConnectTests/SQLTests.swift index 7f7ae8b..498d3d2 100644 --- a/Tests/SparkConnectTests/SQLTests.swift +++ b/Tests/SparkConnectTests/SQLTests.swift @@ -69,6 +69,13 @@ struct SQLTests { #expect(removeOwner("185") == "*") } + let queriesForSpark4Only: [String] = [ + "create_scala_function.sql", + "create_table_function.sql", + "pipesyntax.sql", + "explain.sql", + ] + #if !os(Linux) @Test func runAll() async throws { @@ -76,6 +83,10 @@ struct SQLTests { for name in try! fm.contentsOfDirectory(atPath: path).sorted() { guard name.hasSuffix(".sql") else { continue } print(name) + if queriesForSpark4Only.contains(name) { + print("Skip query \(name) due to the difference between Spark 3 and 4.") + continue + } let sql = try String(contentsOf: URL(fileURLWithPath: "\(path)/\(name)"), encoding: .utf8) let answer = cleanUp(try await spark.sql(sql).collect().map { $0.toString() }.joined(separator: "\n")) diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index b3410ef..d5a824b 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -84,15 +84,15 @@ struct SparkConnectClientTests { await client.stop() } -#if !os(Linux) // TODO: Enable this with the offical Spark 4 docker image @Test func jsonToDdl() async throws { let client = SparkConnectClient(remote: TEST_REMOTE) - let _ = try await client.connect(UUID().uuidString) - let json = + let response = try await client.connect(UUID().uuidString) + if response.sparkVersion.version.starts(with: "4.") { + let json = #"{"type":"struct","fields":[{"name":"id","type":"long","nullable":false,"metadata":{}}]}"# - #expect(try await client.jsonToDdl(json) == "id BIGINT NOT NULL") + #expect(try await client.jsonToDdl(json) == "id BIGINT NOT NULL") + } await client.stop() } -#endif } diff --git a/Tests/SparkConnectTests/SparkSessionTests.swift b/Tests/SparkConnectTests/SparkSessionTests.swift index f864b09..f730d9c 100644 --- a/Tests/SparkConnectTests/SparkSessionTests.swift +++ b/Tests/SparkConnectTests/SparkSessionTests.swift @@ -53,7 +53,8 @@ struct SparkSessionTests { @Test func version() async throws { let spark = try await SparkSession.builder.getOrCreate() - #expect(await spark.version.starts(with: "4.0.0")) + let version = await spark.version + #expect(version.starts(with: "4.0.0") || version.starts(with: "3.5.")) await spark.stop() } @@ -80,7 +81,7 @@ struct SparkSessionTests { let tableName = "TABLE_" + UUID().uuidString.replacingOccurrences(of: "-", with: "") let spark = try await SparkSession.builder.getOrCreate() try await SQLHelper.withTable(spark, tableName)({ - _ = try await spark.sql("CREATE TABLE \(tableName) AS VALUES (1), (2)").count() + _ = try await spark.sql("CREATE TABLE \(tableName) USING ORC AS VALUES (1), (2)").count() #expect(try await spark.table(tableName).count() == 2) }) await spark.stop()