diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ExtendedSpanner.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ExtendedSpanner.java new file mode 100644 index 00000000000..a1b65ecfefb --- /dev/null +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/ExtendedSpanner.java @@ -0,0 +1,44 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +public interface ExtendedSpanner extends Spanner { + /** + * Returns a {@code DatabaseClient} for the given database and given client id. It uses a pool of + * sessions to talk to the database. + * + * + *
{@code
+   * SpannerOptions options = SpannerOptions.newBuilder().build();
+   * Spanner spanner = options.getService();
+   * final String project = "test-project";
+   * final String instance = "test-instance";
+   * final String database = "example-db";
+   * final String client_id = "client_id"
+   * DatabaseId db =
+   *     DatabaseId.of(project, instance, database);
+   *
+   * DatabaseClient dbClient = spanner.getDatabaseClient(db, client_id);
+   * }
+ * + * + */ + default DatabaseClient getDatabaseClient(DatabaseId db, String clientId) { + throw new UnsupportedOperationException( + "getDatabaseClient with clientId is not supported by this default implementation."); + } +} diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 63d501fbe63..c4b11cc859d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -59,7 +59,7 @@ import javax.annotation.concurrent.GuardedBy; /** Default implementation of the Cloud Spanner interface. */ -class SpannerImpl extends BaseService implements Spanner { +class SpannerImpl extends BaseService implements ExtendedSpanner { private static final Logger logger = Logger.getLogger(SpannerImpl.class.getName()); final TraceWrapper tracer = new TraceWrapper( @@ -254,9 +254,13 @@ public InstanceAdminClient getInstanceAdminClient() { @Override public DatabaseClient getDatabaseClient(DatabaseId db) { + return getDatabaseClient(db, null); + } + + @Override + public DatabaseClient getDatabaseClient(DatabaseId db, String clientId) { synchronized (this) { checkClosed(); - String clientId = null; if (dbClients.containsKey(db) && !dbClients.get(db).isValid()) { // Close the invalidated client and remove it. dbClients.get(db).closeAsync(new ClosedException()); diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java index c1e8839534f..0c73ba11ffb 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionImpl.java @@ -58,6 +58,7 @@ import com.google.cloud.spanner.DatabaseId; import com.google.cloud.spanner.Dialect; import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.ExtendedSpanner; import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.Options; import com.google.cloud.spanner.Options.QueryOption; @@ -108,6 +109,7 @@ import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.Stack; import java.util.UUID; @@ -157,6 +159,7 @@ class ConnectionImpl implements Connection { private static final ParsedStatement RELEASE_STATEMENT = AbstractStatementParser.getInstance(Dialect.GOOGLE_STANDARD_SQL) .parse(Statement.of("RELEASE s1")); + private static final String CLIENT_ID = "client_id"; /** * Exception that is used to register the stacktrace of the code that opened a {@link Connection}. @@ -251,8 +254,7 @@ static UnitOfWorkType of(TransactionMode transactionMode) { } } - private StatementExecutor.StatementTimeout statementTimeout = - new StatementExecutor.StatementTimeout(); + private StatementTimeout statementTimeout = new StatementTimeout(); private boolean closed = false; private final Spanner spanner; @@ -323,7 +325,25 @@ static UnitOfWorkType of(TransactionMode transactionMode) { EmulatorUtil.maybeCreateInstanceAndDatabase( spanner, options.getDatabaseId(), options.getDialect()); } - this.dbClient = spanner.getDatabaseClient(options.getDatabaseId()); + DatabaseClient tempDbClient = null; + final DatabaseId databaseId = options.getDatabaseId(); + try { + Optional clientIdOpt = extractClientIdOptional(options); + if (clientIdOpt.isPresent() && !clientIdOpt.get().isEmpty()) { + if (this.spanner instanceof ExtendedSpanner) { + ExtendedSpanner extendedSpanner = (ExtendedSpanner) this.spanner; + tempDbClient = extendedSpanner.getDatabaseClient(databaseId, clientIdOpt.get()); + } + } + } catch (Exception e) { + System.err.println( + "WARNING: Failed during DatabaseClient initialization (possibly getting specific ID), falling back to default. Error: " + + e.getMessage()); + } + if (tempDbClient == null) { + tempDbClient = spanner.getDatabaseClient(databaseId); + } + this.dbClient = tempDbClient; this.batchClient = spanner.getBatchClient(options.getDatabaseId()); this.ddlClient = createDdlClient(); this.connectionState = @@ -340,6 +360,14 @@ && getDialect() == Dialect.POSTGRESQL setDefaultTransactionOptions(getDefaultIsolationLevel()); } + private Optional extractClientIdOptional(ConnectionOptions options) { + return Optional.ofNullable(options.getInitialConnectionPropertyValues()) + .map(props -> props.get(CLIENT_ID)) + .map(ConnectionPropertyValue::getValue) + .map(Object::toString) + .filter(id -> !id.isEmpty()); + } + /** Constructor only for test purposes. */ @VisibleForTesting ConnectionImpl( @@ -411,7 +439,7 @@ static Attributes createOpenTelemetryAttributes(DatabaseId databaseId) { } @VisibleForTesting - ConnectionState.Type getConnectionStateType() { + Type getConnectionStateType() { return this.connectionState.getType(); } @@ -500,7 +528,7 @@ private void reset(Context context, boolean inTransaction) { this.connectionState.resetValue(AUTOCOMMIT_DML_MODE, context, inTransaction); this.statementTag = null; - this.statementTimeout = new StatementExecutor.StatementTimeout(); + this.statementTimeout = new StatementTimeout(); this.connectionState.resetValue(DIRECTED_READ, context, inTransaction); this.connectionState.resetValue(SAVEPOINT_SUPPORT, context, inTransaction); this.protoDescriptors = null; @@ -541,8 +569,7 @@ public boolean isClosed() { return closed; } - private T getConnectionPropertyValue( - com.google.cloud.spanner.connection.ConnectionProperty property) { + private T getConnectionPropertyValue(ConnectionProperty property) { return this.connectionState.getValue(property).getValue(); } @@ -562,9 +589,8 @@ private void setConnectionPropertyValue( /** * Sets a connection property value only for the duration of the current transaction. The effects * of this will be undone once the transaction ends, regardless whether the transaction is - * committed or rolled back. 'Local' properties are supported for both {@link - * com.google.cloud.spanner.connection.ConnectionState.Type#TRANSACTIONAL} and {@link - * com.google.cloud.spanner.connection.ConnectionState.Type#NON_TRANSACTIONAL} connection states. + * committed or rolled back. 'Local' properties are supported for both {@link Type#TRANSACTIONAL} + * and {@link Type#NON_TRANSACTIONAL} connection states. * *

NOTE: This feature is not yet exposed in the public API. */ diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java index 6f945938df0..fbc59caa174 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionOptions.java @@ -21,6 +21,7 @@ import static com.google.cloud.spanner.connection.ConnectionProperties.AUTO_PARTITION_MODE; import static com.google.cloud.spanner.connection.ConnectionProperties.CHANNEL_PROVIDER; import static com.google.cloud.spanner.connection.ConnectionProperties.CLIENT_CERTIFICATE; +import static com.google.cloud.spanner.connection.ConnectionProperties.CLIENT_ID; import static com.google.cloud.spanner.connection.ConnectionProperties.CLIENT_KEY; import static com.google.cloud.spanner.connection.ConnectionProperties.CREDENTIALS_PROVIDER; import static com.google.cloud.spanner.connection.ConnectionProperties.CREDENTIALS_URL; @@ -539,6 +540,11 @@ public Builder setTracingPrefix(String tracingPrefix) { return this; } + public Builder setClientId(String clientId) { + setConnectionPropertyValue(CLIENT_ID, clientId); + return this; + } + /** @return the {@link ConnectionOptions} */ public ConnectionOptions build() { Preconditions.checkState(this.uri != null, "Connection URI is required"); @@ -603,7 +609,6 @@ private ConnectionOptions(Builder builder) { // Create the initial connection state from the parsed properties in the connection URL. this.initialConnectionState = new ConnectionState(connectionPropertyValues); - // Check that at most one of credentials location, encoded credentials, credentials provider and // OUAuth token has been specified in the connection URI. Preconditions.checkArgument( diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionProperties.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionProperties.java index 54d3461b787..22f541f3bd2 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionProperties.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/connection/ConnectionProperties.java @@ -135,6 +135,14 @@ public class ConnectionProperties { private static final Boolean[] BOOLEANS = new Boolean[] {Boolean.TRUE, Boolean.FALSE}; + static final ConnectionProperty CLIENT_ID = + create( + "client_id", + "Client Id to use for this connection. Can only be set at the start up time", + null, + StringValueConverter.INSTANCE, + Context.STARTUP); + static final ConnectionProperty CONNECTION_STATE_TYPE = create( "connection_state_type", diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java index 3cf13dc58d3..af73b2b7e5f 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SpannerImplTest.java @@ -358,6 +358,43 @@ public void testCreateInstanceAdminClient_whenMockAdminSettings_assertException( assertNotNull(instanceAdminClient); } + @Test + public void testGetDatabaseClient_when_clientId_is_not_null() { + String dbName = + String.format("projects/p1/instances/i1/databases/%s", UUID.randomUUID().toString()); + DatabaseId db = DatabaseId.of(dbName); + + Mockito.when(spannerOptions.getTransportOptions()) + .thenReturn(GrpcTransportOptions.newBuilder().build()); + Mockito.when(spannerOptions.getSessionPoolOptions()) + .thenReturn(SessionPoolOptions.newBuilder().setMinSessions(0).build()); + Mockito.when(spannerOptions.getDatabaseRole()).thenReturn("role"); + + DatabaseClientImpl databaseClient = + (DatabaseClientImpl) impl.getDatabaseClient(db, "clientId-1"); + assertThat(databaseClient.clientId).isEqualTo("clientId-1"); + + // Get same db client again. + DatabaseClientImpl databaseClient1 = + (DatabaseClientImpl) impl.getDatabaseClient(db, "clientId-1"); + assertThat(databaseClient1.clientId).isEqualTo(databaseClient.clientId); + + // Get a db client for a different database. + String dbName2 = + String.format("projects/p1/instances/i1/databases/%s", UUID.randomUUID().toString()); + DatabaseId db2 = DatabaseId.of(dbName2); + DatabaseClientImpl databaseClient2 = + (DatabaseClientImpl) impl.getDatabaseClient(db2, "clientId-1"); + assertThat(databaseClient2.clientId).isEqualTo("clientId-1"); + + // Getting a new database client for an invalidated database should use the same client id. + databaseClient.pool.setResourceNotFoundException( + new DatabaseNotFoundException(DoNotConstructDirectly.ALLOWED, "not found", null, null)); + DatabaseClientImpl revalidated = (DatabaseClientImpl) impl.getDatabaseClient(db, "clientId-1"); + assertThat(revalidated).isNotSameInstanceAs(databaseClient); + assertThat(revalidated.clientId).isEqualTo(databaseClient.clientId); + } + private void closeSpannerAndIncludeStacktrace(Spanner spanner) { spanner.close(); }