Skip to content

Commit 5b996ff

Browse files
committed
Implements Embedding in MySQL
(cherry picked from commit ca9cdce)
1 parent 5b352fb commit 5b996ff

File tree

5 files changed

+65
-6
lines changed

5 files changed

+65
-6
lines changed

java/src/main/java/com/genexus/db/ForEachCursor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,8 @@ public void setParameterRT(String name, String value)
340340
boolean isLike = false;
341341
if(value.equals("like"))
342342
isLike = true;
343+
else if (value.equals("Distance"))
344+
value = ds.getDistanceFunction();
343345
else if(!value.equals("=") && !value.equals(">") && !value.equals(">=")
344346
&& !value.equals("<=") && !value.equals("<") && !value.equals("<>"))
345347
{

java/src/main/java/com/genexus/db/driver/DataSource.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ public class DataSource extends AbstractDataSource
7171
public String jdbcDataSource;
7272

7373
private String namespace;
74+
private String dbmsVersion;
7475

7576
public DataSource(
7677
String name,
@@ -581,7 +582,15 @@ public DataSource copy()
581582
copyDataSource.setConnectionPools(this.getConnectionPools());
582583
return copyDataSource;
583584
}
584-
585+
586+
public String getDistanceFunction() {
587+
String distanceFunction = "DISTANCE";
588+
if (dbms.getId() == GXDBMS.DBMS_MYSQL && getDbmsVersion() != null && getDbmsVersion().contains("MariaDB")) {
589+
distanceFunction = "vec_distance_cosine";
590+
}
591+
return distanceFunction;
592+
}
593+
585594
public String[] concatOp()
586595
{
587596
switch(dbms.getId())
@@ -616,7 +625,15 @@ public void RWPoolRecycle()
616625
{
617626
if (getConnectionPool().getRWConnectionPool(defaultUser) != null)
618627
getConnectionPool().getRWConnectionPool(defaultUser).PoolRecycle();
619-
}
628+
}
629+
630+
public void setDbmsVersion(String dbmsVersion) {
631+
this.dbmsVersion = dbmsVersion;
632+
}
633+
634+
public String getDbmsVersion() {
635+
return dbmsVersion;
636+
}
620637

621638
}
622639

java/src/main/java/com/genexus/db/driver/GXConnection.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ public GXConnection( ModelContext context, int handle, String user, String passw
190190
try
191191
{
192192
version = dma.getDatabaseProductVersion();
193+
dataSource.setDbmsVersion(version);
193194
}
194195
catch (SQLException e)
195196
{

java/src/main/java/com/genexus/db/driver/GXPreparedStatement.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.math.BigDecimal;
1414
import java.net.MalformedURLException;
1515
import java.net.URL;
16+
import java.nio.ByteBuffer;
1617
import java.sql.Array;
1718
import java.sql.Blob;
1819
import java.sql.Clob;
@@ -1506,14 +1507,35 @@ else if(blobFiles.length < index)
15061507
}
15071508
}
15081509

1510+
private static byte[] floatArrayToByteArray(Float[] floats) {
1511+
ByteBuffer buffer = ByteBuffer.allocate(floats.length * Float.BYTES);
1512+
1513+
for (Float f : floats) {
1514+
if (f != null) {
1515+
buffer.putFloat(f);
1516+
} else {
1517+
buffer.putFloat(0.0f);
1518+
}
1519+
}
1520+
return buffer.array();
1521+
}
1522+
15091523
public void setEmbedding(int index, Float[] value) throws SQLException{
1510-
Array sqlArray = con.createArrayOf("float4", value);
1524+
byte[] bytes = null;
1525+
Array sqlArray = null;
1526+
if (con.getDBMS().getId() == GXDBMS.DBMS_POSTGRESQL)
1527+
sqlArray = con.createArrayOf("float4", value);
1528+
else
1529+
bytes = floatArrayToByteArray(value);
15111530
if (DEBUG)
15121531
{
15131532
log(GXDBDebug.LOG_MAX, "setEmbedding - index : " + index);
15141533
try
15151534
{
1516-
stmt.setArray(index, sqlArray);
1535+
if (con.getDBMS().getId() == GXDBMS.DBMS_POSTGRESQL)
1536+
stmt.setArray(index, sqlArray);
1537+
else
1538+
stmt.setBytes(index, bytes);
15171539
}
15181540
catch (SQLException sqlException)
15191541
{
@@ -1523,7 +1545,10 @@ public void setEmbedding(int index, Float[] value) throws SQLException{
15231545
}
15241546
else
15251547
{
1526-
stmt.setArray(index, sqlArray);
1548+
if (con.getDBMS().getId() == GXDBMS.DBMS_POSTGRESQL)
1549+
stmt.setArray(index, sqlArray);
1550+
else
1551+
stmt.setBytes(index, bytes);
15271552
}
15281553
}
15291554

java/src/main/java/com/genexus/db/driver/GXResultSet.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.io.OutputStream;
1010
import java.io.Reader;
1111
import java.math.BigDecimal;
12+
import java.nio.ByteBuffer;
1213
import java.sql.Array;
1314
import java.sql.Blob;
1415
import java.sql.Clob;
@@ -836,7 +837,20 @@ public Float[] getGxembedding (int columnIndex) throws SQLException
836837
if (DEBUG )
837838
log(GXDBDebug.LOG_MAX, "Warning: getEmbedding");
838839

839-
return(Float[]) convertVectorStringToFloatArray(result.getArray(columnIndex).toString());
840+
if (con.getDBMS().getId() == GXDBMS.DBMS_POSTGRESQL)
841+
return convertVectorStringToFloatArray(result.getArray(columnIndex).toString());
842+
else
843+
return byteArrayToFloatObjectArray(result.getBytes(columnIndex));
844+
}
845+
846+
private static Float[] byteArrayToFloatObjectArray(byte[] bytes) {
847+
Float[] floats = new Float[bytes.length / Float.BYTES];
848+
849+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
850+
for (int i = 0; i < floats.length; i++) {
851+
floats[i] = buffer.getFloat();
852+
}
853+
return floats;
840854
}
841855

842856
private static Float[] convertVectorStringToFloatArray(String vectorString) {

0 commit comments

Comments
 (0)