From 2fc644c672d7e1b782c3f3635b0eb3df744737db Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 17 Jun 2024 14:35:38 +0000 Subject: [PATCH 01/17] add dockerfile --- Dockerfile-rocm | 130 ++++++++++++++++++++++++++ backends/python/Makefile-flash-att-v2 | 21 +++++ 2 files changed, 151 insertions(+) create mode 100644 Dockerfile-rocm create mode 100644 backends/python/Makefile-flash-att-v2 diff --git a/Dockerfile-rocm b/Dockerfile-rocm new file mode 100644 index 00000000..f9e56631 --- /dev/null +++ b/Dockerfile-rocm @@ -0,0 +1,130 @@ +FROM rocm/dev-ubuntu-22.04:6.0.2 AS base-builder + +ENV SCCACHE=0.5.4 +ENV RUSTC_WRAPPER=/usr/local/bin/sccache +ENV PATH="/root/.cargo/bin:${PATH}" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + libssl-dev \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +# Donwload and configure sccache +RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ + chmod +x /usr/local/bin/sccache + +RUN curl https://sh.rustup.rs -sSf | bash -s -- -y +RUN cargo install cargo-chef --locked + +FROM base-builder AS planner + +WORKDIR /usr/src + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN cargo chef prepare --recipe-path recipe.json + +FROM base-builder AS builder + +ARG CUDA_COMPUTE_CAP=80 +ARG GIT_SHA +ARG DOCKER_LABEL + +# sccache specific variables +ARG ACTIONS_CACHE_URL +ARG ACTIONS_RUNTIME_TOKEN +ARG SCCACHE_GHA_ENABLED + +WORKDIR /usr/src + +COPY --from=planner /usr/src/recipe.json recipe.json + +RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s + +COPY backends backends +COPY core core +COPY router router +COPY Cargo.toml ./ +COPY Cargo.lock ./ + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + unzip \ + && rm -rf /var/lib/apt/lists/* + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY proto proto + +FROM builder as http-builder + +RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s + +FROM builder as grpc-builder + +RUN cargo build --release --bin text-embeddings-router -F python -F grpc --no-default-features && sccache -s + +FROM rocm/dev-ubuntu-22.04:6.0.2 as base + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + git \ + python3-dev \ + rocthrust-dev \ + hipsparse-dev \ + hipblas-dev \ + hipblaslt-dev \ + rocblas-dev \ + hiprand-dev \ + rocrand-dev \ + && rm -rf /var/lib/apt/lists/* + + +# Keep in sync with `server/pyproject.toml +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTORCH_VERSION='2.3.0' +ARG ROCM_VERSION='6.0.2' +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +RUN curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + mamba init && \ + 
rm ~/mambaforge.sh + +# Install flash-attention, torch dependencies +RUN pip install numpy einops ninja --no-cache-dir + +RUN pip install torch --index-url https://download.pytorch.org/whl/rocm6.0 + +ARG DEFAULT_USE_FLASH_ATTENTION=True +COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 +RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm + +ENV HUGGINGFACE_HUB_CACHE=/data \ + PORT=80 \ + USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION + +FROM base as grpc + +COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] + +FROM base + +COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router + +ENTRYPOINT ["text-embeddings-router"] +CMD ["--json-output"] diff --git a/backends/python/Makefile-flash-att-v2 b/backends/python/Makefile-flash-att-v2 new file mode 100644 index 00000000..ba90a74d --- /dev/null +++ b/backends/python/Makefile-flash-att-v2 @@ -0,0 +1,21 @@ +flash_att_v2_commit_cuda := v2.5.9.post1 +flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6 + +build-flash-attention-v2-cuda: + pip install -U packaging wheel + pip install flash-attn==$(flash_att_v2_commit_cuda) + +install-flash-attention-v2-cuda: build-flash-attention-v2-cuda + echo "Flash v2 installed" + +build-flash-attention-v2-rocm: + if [ ! -d 'flash-attention-v2' ]; then \ + pip install -U packaging ninja --no-cache-dir && \ + git clone https://github.com/ROCm/flash-attention.git flash-attention-v2 && \ + cd flash-attention-v2 && git fetch && git checkout $(flash_att_v2_commit_rocm) && \ + git submodule update --init --recursive && GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build; \ + fi + +install-flash-attention-v2-rocm: build-flash-attention-v2-rocm + cd flash-attention-v2 && \ + GPU_ARCHS="gfx90a;gfx942" PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py install From 37d29316243298970ba2d3e889f0757dc2d74e7a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:05:09 +0000 Subject: [PATCH 02/17] working cls pooling --- backends/python/server/pyproject.toml | 7 +- backends/python/server/requirements.txt | 12 --- .../server/text_embeddings_server/cli.py | 3 +- .../text_embeddings_server/models/__init__.py | 18 ++-- .../models/flash_bert.py | 59 ++++-------- .../server/text_embeddings_server/server.py | 3 +- .../utils/flash_attn.py | 92 ------------------- router/src/lib.rs | 1 + 8 files changed, 38 insertions(+), 157 deletions(-) delete mode 100644 backends/python/server/text_embeddings_server/utils/flash_attn.py diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml index 96fcaf9e..839ff27a 100644 --- a/backends/python/server/pyproject.toml +++ b/backends/python/server/pyproject.toml @@ -20,7 +20,7 @@ loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" -torch = { version = "^2.0.1" } +torch = { version = "==2.3.1" } [tool.poetry.extras] @@ -33,6 +33,11 @@ name = "pytorch-gpu-src" url = "https://download.pytorch.org/whl/cu118" priority = "explicit" +[[tool.poetry.source]] +name = "pytorch-gpu-src-rocm" +url = "https://download.pytorch.org/whl/rocm6.0" +priority = "explicit" + [tool.pytest.ini_options] markers = ["private: marks tests as requiring an admin hf token (deselect with '-m \"not private\"')"] diff 
--git a/backends/python/server/requirements.txt b/backends/python/server/requirements.txt index 89ca314d..2d089e41 100644 --- a/backends/python/server/requirements.txt +++ b/backends/python/server/requirements.txt @@ -4,20 +4,13 @@ charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13" -filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13" -fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13" -huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13" idna==3.4 ; python_version >= "3.9" and python_version < "3.13" -jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" -markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13" -mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13" -networkx==3.1 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" @@ -27,15 +20,10 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" -packaging==23.1 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13" -pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13" -sympy==1.12 ; python_version >= "3.9" and python_version < "3.13" -torch==2.0.1 ; python_version >= "3.9" and python_version < "3.13" -tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13" diff --git a/backends/python/server/text_embeddings_server/cli.py b/backends/python/server/text_embeddings_server/cli.py index 4c627515..70e60d80 100644 --- a/backends/python/server/text_embeddings_server/cli.py +++ b/backends/python/server/text_embeddings_server/cli.py @@ -23,6 +23,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + pooling_mode: Optional[str] = None, ): # Remove default 
handler logger.remove() @@ -47,7 +48,7 @@ def serve( # Downgrade enum into str for easier management later on dtype = None if dtype is None else dtype.value - server.serve(model_path, dtype, uds_path) + server.serve(model_path, dtype, uds_path, pooling_mode) if __name__ == "__main__": diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 47867187..7f480b33 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -15,17 +15,19 @@ torch.set_grad_enabled(False) FLASH_ATTENTION = True -try: - from text_embeddings_server.models.flash_bert import FlashBert -except ImportError as e: - logger.warning(f"Could not import Flash Attention enabled models: {e}") - FLASH_ATTENTION = False +# try: +from text_embeddings_server.models.flash_bert import FlashBert +# except ImportError as e: +# logger.warning(f"Could not import Flash Attention enabled models: {e}") +# FLASH_ATTENTION = False if FLASH_ATTENTION: __all__.append(FlashBert) -def get_model(model_path: Path, dtype: Optional[str]): +class + +def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str): if dtype == "float32": dtype = torch.float32 elif dtype == "float16": @@ -52,8 +54,8 @@ def get_model(model_path: Path, dtype: Optional[str]): and dtype in [torch.float16, torch.bfloat16] and FLASH_ATTENTION ): - return FlashBert(model_path, device, dtype) + return FlashBert(model_path, device, dtype, pooling_mode) else: - return DefaultModel(model_path, device, dtype) + return DefaultModel(model_path, device, dtype, pooling_mode) raise NotImplementedError diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 50b8d70d..67176f2c 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -8,46 +8,15 @@ from transformers.models.bert import BertConfig from opentelemetry import trace -# Flash attention imports -import dropout_layer_norm - from text_embeddings_server.models import Model from text_embeddings_server.models.types import FlashBatch, Embedding -from text_embeddings_server.utils.flash_attn import attention +from text_embeddings_server.layers.attention import attention +from text_embeddings_server.layers.layernorm import FastLayerNorm +from loguru import logger tracer = trace.get_tracer(__name__) -class FastLayerNorm: - def __init__(self, prefix, handle, device, dtype, config: BertConfig): - self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) - self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) - self.variance_epsilon = config.layer_norm_eps - - def forward(self, hidden_states, residual=None): - normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.weight, - self.bias, - None, - None, - None, - None, - 0.0, - self.variance_epsilon, - 1.0, - 0, - None, - False, - False, - ) - if res is None: - res = hidden_states - - return normed_hidden_states, res - - class BertEmbeddings: def __init__(self, prefix, handle, device, dtype, config: BertConfig): self.word_embeddings_weight = ( @@ -217,7 +186,7 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): embeddings = self.embeddings.forward(input_ids, token_type_ids, position_ids) 
encoder_outputs = self.encoder.forward(embeddings, cu_seqlens, max_s) - return encoder_outputs[cu_seqlens[:-1]] + return encoder_outputs class FlashBert(Model): @@ -236,6 +205,7 @@ def batch_type(self) -> Type[FlashBatch]: @tracer.start_as_current_span("embed") def embed(self, batch: FlashBatch) -> List[Embedding]: + logger.info(f"batch.input_ids {batch.input_ids}") embedding = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, @@ -243,11 +213,16 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: cu_seqlens=batch.cu_seqlens, max_s=batch.max_s, ) - cpu_results = embedding.view(-1).tolist() - return [ - Embedding( - values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] - ) - for i in range(len(batch)) - ] + if True: + embedding = embedding[batch.cu_seqlens[:-1]] + logger.info(f"embedding {embedding.shape}") + cpu_results = embedding.view(-1).tolist() + + return [ + Embedding( + values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] + ) + for i in range(len(batch)) + ] + elif diff --git a/backends/python/server/text_embeddings_server/server.py b/backends/python/server/text_embeddings_server/server.py index d0a43ace..2c99cf79 100644 --- a/backends/python/server/text_embeddings_server/server.py +++ b/backends/python/server/text_embeddings_server/server.py @@ -37,6 +37,7 @@ def serve( model_path: Path, dtype: Optional[str], uds_path: Path, + pooling_mode: Optional[str], ): async def serve_inner( model_path: Path, @@ -45,7 +46,7 @@ async def serve_inner( unix_socket = f"unix://{uds_path}" try: - model = get_model(model_path, dtype) + model = get_model(model_path, dtype, pooling_mode) except Exception: logger.exception("Error when initializing model") raise diff --git a/backends/python/server/text_embeddings_server/utils/flash_attn.py b/backends/python/server/text_embeddings_server/utils/flash_attn.py deleted file mode 100644 index 1d325351..00000000 --- a/backends/python/server/text_embeddings_server/utils/flash_attn.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import torch - -from loguru import logger - -if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") - -if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") - -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -is_sm8x = major == 8 and minor >= 0 -is_sm90 = major == 9 and minor == 0 - -HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2 = False -try: - try: - import flash_attn_2_cuda - except ImportError: - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" - ) - if not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2 = True -except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = 
True - - -def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): - if HAS_FLASH_ATTN_V2: - return flash_attn_2_cuda.varlen_fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - is_causal, - -1, - -1, - False, - None, - ) - - if HAS_FLASH_ATTN: - return flash_attn_cuda.fwd( - q, - k, - v, - out, - cu_seqlens, - cu_seqlens, - max_s, - max_s, - 0.0, - softmax_scale, - False, - is_causal, - False, - 0, - None, - ) - - raise NotImplementedError("flash attention is not installed") diff --git a/router/src/lib.rs b/router/src/lib.rs index 3801af8a..e387b3cb 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -198,6 +198,7 @@ pub async fn run( backend_model_type, uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), + pooling.to_string(), ) .context("Could not create backend")?; backend From 8584b6d4480794625eb62e1e2e3af80538a26f2a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Tue, 18 Jun 2024 13:07:45 +0000 Subject: [PATCH 03/17] add layers --- .../text_embeddings_server/layers/__init__.py | 0 .../layers/attention/__init__.py | 11 +++ .../layers/attention/cuda.py | 92 +++++++++++++++++++ .../layers/attention/rocm.py | 45 +++++++++ .../layers/layernorm.py | 54 +++++++++++ .../utils/import_utils.py | 12 +++ 6 files changed, 214 insertions(+) create mode 100644 backends/python/server/text_embeddings_server/layers/__init__.py create mode 100644 backends/python/server/text_embeddings_server/layers/attention/__init__.py create mode 100644 backends/python/server/text_embeddings_server/layers/attention/cuda.py create mode 100644 backends/python/server/text_embeddings_server/layers/attention/rocm.py create mode 100644 backends/python/server/text_embeddings_server/layers/layernorm.py create mode 100644 backends/python/server/text_embeddings_server/utils/import_utils.py diff --git a/backends/python/server/text_embeddings_server/layers/__init__.py b/backends/python/server/text_embeddings_server/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backends/python/server/text_embeddings_server/layers/attention/__init__.py b/backends/python/server/text_embeddings_server/layers/attention/__init__.py new file mode 100644 index 00000000..9cce5d34 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/__init__.py @@ -0,0 +1,11 @@ +from text_embeddings_server.utils.import_utils import SYSTEM +import os + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") +if SYSTEM == "cuda": + from .cuda import attention +elif SYSTEM == "rocm": + from .rocm import attention +else: + raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/backends/python/server/text_embeddings_server/layers/attention/cuda.py b/backends/python/server/text_embeddings_server/layers/attention/cuda.py new file mode 100644 index 00000000..1d325351 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/cuda.py @@ -0,0 +1,92 @@ +import os +import torch + +from loguru import logger + +if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": + raise ImportError("`USE_FLASH_ATTENTION` is false.") + +if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 +is_sm8x = major == 8 and minor >= 0 +is_sm90 = 
major == 9 and minor == 0 + +HAS_FLASH_ATTN = False +HAS_FLASH_ATTN_V2 = False +try: + try: + import flash_attn_2_cuda + except ImportError: + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention v2 with `cd server && make install install-flash-attention-v2`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2 = True +except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True + + +def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): + if HAS_FLASH_ATTN_V2: + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + -1, + -1, + False, + None, + ) + + if HAS_FLASH_ATTN: + return flash_attn_cuda.fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + False, + 0, + None, + ) + + raise NotImplementedError("flash attention is not installed") diff --git a/backends/python/server/text_embeddings_server/layers/attention/rocm.py b/backends/python/server/text_embeddings_server/layers/attention/rocm.py new file mode 100644 index 00000000..365e5451 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/attention/rocm.py @@ -0,0 +1,45 @@ +import os +import torch +from text_embeddings_server.utils.import_utils import SYSTEM +from loguru import logger + +major, minor = torch.cuda.get_device_capability() +is_sm75 = major == 7 and minor == 5 + +if SYSTEM == "rocm": + try: + import flash_attn_2_cuda + + logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.") + except ImportError as e: + if major >= 8 or is_sm75: + architecture_suffix = f"-{SYSTEM}" + raise ImportError(f"Flash Attention V2 is not installed. 
{e}") + else: + for idx in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(idx) + if "MI210" not in name and "MI250" not in name and "MI300" not in name: + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) + raise ImportError( + f"AMD GPU with ROCm capability {major} {minor} is not supported" + ) from e + +def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): + return flash_attn_2_cuda.varlen_fwd( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + is_causal, + False, + None, + ) \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/layers/layernorm.py b/backends/python/server/text_embeddings_server/layers/layernorm.py new file mode 100644 index 00000000..abd9e676 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/layernorm.py @@ -0,0 +1,54 @@ +import torch +from text_embeddings_server.utils.import_utils import SYSTEM + +from transformers.models.bert import BertConfig + +if SYSTEM == "cuda": + import dropout_layer_norm + + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + normed_hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.variance_epsilon, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + +elif SYSTEM == "rocm": + class FastLayerNorm: + def __init__(self, prefix, handle, device, dtype, config: BertConfig): + self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) + self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) + self.variance_epsilon = config.layer_norm_eps + + def forward(self, hidden_states, residual=None): + if residual is not None: + hidden_states += residual + residual = hidden_states + + hidden_states = torch.nn.functional.layer_norm(hidden_states, self.weight.shape, self.weight, self.bias, eps=self.variance_epsilon) + + return hidden_states, residual +else: + raise ValueError("System not recognized") \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/utils/import_utils.py b/backends/python/server/text_embeddings_server/utils/import_utils.py new file mode 100644 index 00000000..83394eaa --- /dev/null +++ b/backends/python/server/text_embeddings_server/utils/import_utils.py @@ -0,0 +1,12 @@ +import torch +from loguru import logger + +SYSTEM = None +if torch.version.hip is not None: + SYSTEM = "rocm" +elif torch.version.cuda is not None and torch.cuda.is_available(): + SYSTEM = "cuda" +else: + SYSTEM = "cpu" + +logger.info(f"Python backend: detected system {SYSTEM}") From 2a2993a38655132fe4c9367a40b5273f783adc90 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 08:48:26 +0000 Subject: [PATCH 04/17] support mean pooling in python backend --- .../text_embeddings_server/layers/pooling.py | 22 +++++++++++++++++++ .../text_embeddings_server/models/__init__.py | 2 -- .../models/default_model.py | 4 +++- .../models/flash_bert.py | 22 ++++++++++++++----- 
backends/python/src/lib.rs | 6 +++-- backends/python/src/management.rs | 4 ++++ backends/src/lib.rs | 4 ++++ router/src/lib.rs | 17 +++++++++----- 8 files changed, 64 insertions(+), 17 deletions(-) create mode 100644 backends/python/server/text_embeddings_server/layers/pooling.py diff --git a/backends/python/server/text_embeddings_server/layers/pooling.py b/backends/python/server/text_embeddings_server/layers/pooling.py new file mode 100644 index 00000000..1bccbc57 --- /dev/null +++ b/backends/python/server/text_embeddings_server/layers/pooling.py @@ -0,0 +1,22 @@ +import torch +from flash_attn.bert_padding import pad_input + +from loguru import logger + +def mean_pooling(embedding, cu_seqlens, max_s): + # Ideally, rust would pass `indices` to the FlashBatch. + seqlens = cu_seqlens[1:].clone() + seqlens[0] = cu_seqlens[1] + seqlens[1:] -= cu_seqlens[1:-1] + batch_size = len(seqlens) + + # Example: indices = [0, 1, 2, 3, 7, 8, 9, 10, 11, 12, 13] + mask = torch.zeros(batch_size, max_s, dtype=torch.int32, device=cu_seqlens.device) + mask[torch.arange(max_s) < seqlens[:, None].cpu()] = 1 + indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() + + embedding_padded = pad_input(embedding, indices, batch_size, max_s) + + sum_embeddings = torch.sum(embedding_padded, 1) + + return sum_embeddings / seqlens[:, None] \ No newline at end of file diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index 7f480b33..c606efc9 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -25,8 +25,6 @@ __all__.append(FlashBert) -class - def get_model(model_path: Path, dtype: Optional[str], pooling_mode: str): if dtype == "float32": dtype = torch.float32 diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index dc39fdc8..17ad4589 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -8,14 +8,16 @@ from text_embeddings_server.models import Model from text_embeddings_server.models.types import PaddedBatch, Embedding +from typing import Optional tracer = trace.get_tracer(__name__) class DefaultModel(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): model = AutoModel.from_pretrained(model_path).to(dtype).to(device) self.hidden_size = model.config.hidden_size + self.pooling_mode = pooling_mode self.has_position_ids = ( inspect.signature(model.forward).parameters.get("position_ids", None) diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 67176f2c..60be0002 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -12,7 +12,8 @@ from text_embeddings_server.models.types import FlashBatch, Embedding from text_embeddings_server.layers.attention import attention from text_embeddings_server.layers.layernorm import FastLayerNorm -from loguru import logger +from text_embeddings_server.layers.pooling import mean_pooling +from typing import Optional tracer = 
trace.get_tracer(__name__) @@ -190,12 +191,13 @@ def forward(self, input_ids, token_type_ids, position_ids, cu_seqlens, max_s): class FlashBert(Model): - def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype): + def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype, pooling_mode: Optional[str]): config = BertConfig.from_pretrained(model_path) with safe_open(model_path / "model.safetensors", framework="pt") as f: model = FlashBertModel(f, device, dtype, config) self.hidden_size = config.hidden_size + self.pooling_mode = pooling_mode super(FlashBert, self).__init__(model=model, dtype=dtype, device=device) @@ -205,7 +207,6 @@ def batch_type(self) -> Type[FlashBatch]: @tracer.start_as_current_span("embed") def embed(self, batch: FlashBatch) -> List[Embedding]: - logger.info(f"batch.input_ids {batch.input_ids}") embedding = self.model.forward( input_ids=batch.input_ids, token_type_ids=batch.token_type_ids, @@ -214,9 +215,8 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: max_s=batch.max_s, ) - if True: + if self.pooling_mode == "cls": embedding = embedding[batch.cu_seqlens[:-1]] - logger.info(f"embedding {embedding.shape}") cpu_results = embedding.view(-1).tolist() return [ @@ -225,4 +225,14 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: ) for i in range(len(batch)) ] - elif + elif self.pooling_mode == "mean": + res = mean_pooling(embedding, batch.cu_seqlens, batch.max_s) + return [ + Embedding( + values=res[i] + ) + for i in range(len(batch)) + ] + + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") \ No newline at end of file diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 195f1d37..ef33b7d2 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -23,6 +23,7 @@ impl PythonBackend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { match model_type { ModelType::Classifier => { @@ -31,8 +32,8 @@ impl PythonBackend { )) } ModelType::Embedding(pool) => { - if pool != Pool::Cls { - return Err(BackendError::Start(format!("{pool:?} is not supported"))); + if pool != Pool::Cls && pool != Pool::Mean { + return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. 
Please open an issue."))); } pool } @@ -44,6 +45,7 @@ impl PythonBackend { &uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, )?; let tokio_runtime = tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/backends/python/src/management.rs b/backends/python/src/management.rs index 911c6984..2044a3e0 100644 --- a/backends/python/src/management.rs +++ b/backends/python/src/management.rs @@ -22,6 +22,7 @@ impl BackendProcess { uds_path: &str, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { // Get UDS path let uds = Path::new(uds_path); @@ -52,6 +53,9 @@ impl BackendProcess { python_server_args.push("--otlp-service-name".to_owned()); python_server_args.push(otlp_service_name); + python_server_args.push("--pooling-mode".to_owned()); + python_server_args.push(pooling_mode); + // Copy current process env let envs: Vec<(OsString, OsString)> = env::vars_os().collect(); diff --git a/backends/src/lib.rs b/backends/src/lib.rs index d332b4a7..db27cddc 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -39,6 +39,7 @@ impl Backend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result { let (backend_sender, backend_receiver) = mpsc::unbounded_channel(); @@ -49,6 +50,7 @@ impl Backend { uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, )?; let padded_model = backend.is_padded(); let max_batch_size = backend.max_batch_size(); @@ -138,6 +140,7 @@ fn init_backend( uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + pooling_mode: String, ) -> Result, BackendError> { if cfg!(feature = "candle") { #[cfg(feature = "candle")] @@ -158,6 +161,7 @@ fn init_backend( uds_path, otlp_endpoint, otlp_service_name, + pooling_mode, ) }) .join() diff --git a/router/src/lib.rs b/router/src/lib.rs index 2f2fec29..03f8fc41 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -105,7 +105,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, pooling)?; + let backend_model_type = get_backend_model_type(&config, &model_root, &pooling)?; // Info model type let model_type = match &backend_model_type { @@ -191,6 +191,11 @@ pub async fn run( } }); + let pooling_str = match pooling { + Some(pool) => pool.to_string(), + None => "none".to_string(), + }; + // Create backend tracing::info!("Starting model backend"); let backend = text_embeddings_backend::Backend::new( @@ -200,7 +205,7 @@ pub async fn run( uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), otlp_service_name.clone(), - pooling.to_string(), + pooling_str, ) .context("Could not create backend")?; backend @@ -307,10 +312,10 @@ pub async fn run( fn get_backend_model_type( config: &ModelConfig, model_root: &Path, - pooling: Option, + pooling: &Option, ) -> Result { for arch in &config.architectures { - if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") { + if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") { return Ok(text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, )); @@ -324,7 +329,7 @@ fn get_backend_model_type( } } - if Some(text_embeddings_backend::Pool::Splade) == pooling { + if Some(text_embeddings_backend::Pool::Splade) == *pooling { return Err(anyhow!( "Splade pooling is not supported: model is not a ForMaskedLM 
model" )); @@ -332,7 +337,7 @@ fn get_backend_model_type( // Set pooling let pool = match pooling { - Some(pool) => pool, + Some(pool) => pool.clone(), None => { // Load pooling config let config_path = model_root.join("1_Pooling/config.json"); From 36b3a72dadd7973059e796141dbc1525e8b5c909 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 09:06:46 +0000 Subject: [PATCH 05/17] fix dockerfile and install --- Dockerfile-rocm | 5 +++++ backends/python/server/pyproject.toml | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Dockerfile-rocm b/Dockerfile-rocm index f9e56631..152fa0a0 100644 --- a/Dockerfile-rocm +++ b/Dockerfile-rocm @@ -111,6 +111,11 @@ ARG DEFAULT_USE_FLASH_ATTENTION=True COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm +# Install python backend +COPY backends/python/server /tei_backends/python/server +COPY backends/proto tei_backends/proto +RUN make -C /tei_backends/python/server install + ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION diff --git a/backends/python/server/pyproject.toml b/backends/python/server/pyproject.toml index 839ff27a..8fbc0008 100644 --- a/backends/python/server/pyproject.toml +++ b/backends/python/server/pyproject.toml @@ -15,12 +15,13 @@ grpcio-status = "^1.51.1" grpcio-reflection = "^1.51.1" grpc-interceptor = "^0.15.0" typer = "^0.6.1" -safetensors = "^0.3.2" +safetensors = "^0.4.0" loguru = "^0.6.0" opentelemetry-api = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-instrumentation-grpc = "^0.36b0" torch = { version = "==2.3.1" } +transformers = { version = "^4.39.0"} [tool.poetry.extras] From a8c02db493575a2cb95e7aeb20a586e71389f1c5 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 11:32:50 +0000 Subject: [PATCH 06/17] add tests --- .gitignore | 1 + .../layers/attention/__init__.py | 5 +- .../models/flash_bert.py | 1 - router/src/lib.rs | 14 +-- tests/__init__.py | 0 tests/assets/default_bert.pt | 0 tests/assets/flash_bert.pt | 0 tests/conftest.py | 113 ++++++++++++++++++ tests/pytest.ini | 2 + tests/test_default_model.py | 28 +++++ tests/test_flash_bert.py | 28 +++++ 11 files changed, 183 insertions(+), 9 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/assets/default_bert.pt create mode 100644 tests/assets/flash_bert.pt create mode 100644 tests/conftest.py create mode 100644 tests/pytest.ini create mode 100644 tests/test_default_model.py create mode 100644 tests/test_flash_bert.py diff --git a/.gitignore b/.gitignore index ee44a963..6862c2f1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea target +__pycache__/ diff --git a/backends/python/server/text_embeddings_server/layers/attention/__init__.py b/backends/python/server/text_embeddings_server/layers/attention/__init__.py index 9cce5d34..42aac2bd 100644 --- a/backends/python/server/text_embeddings_server/layers/attention/__init__.py +++ b/backends/python/server/text_embeddings_server/layers/attention/__init__.py @@ -2,7 +2,10 @@ import os if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": - raise ImportError("`USE_FLASH_ATTENTION` is false.") + class Attention: + def __getattr__(self, name): + raise RuntimeError(f"TEI is used with USE_FLASH_ATTENTION=false, accessing `attention` is prohibited") + attention = Attention() if SYSTEM == "cuda": from .cuda import attention elif 
SYSTEM == "rocm": diff --git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 60be0002..6ebb70d4 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -233,6 +233,5 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: ) for i in range(len(batch)) ] - else: raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") \ No newline at end of file diff --git a/router/src/lib.rs b/router/src/lib.rs index 03f8fc41..14f1dfb3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -105,7 +105,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let backend_model_type = get_backend_model_type(&config, &model_root, &pooling)?; + let (backend_model_type, inferred_pooling) = get_backend_model_type(&config, &model_root, &pooling)?; // Info model type let model_type = match &backend_model_type { @@ -191,7 +191,7 @@ pub async fn run( } }); - let pooling_str = match pooling { + let pooling_str = match inferred_pooling { Some(pool) => pool.to_string(), None => "none".to_string(), }; @@ -313,19 +313,19 @@ fn get_backend_model_type( config: &ModelConfig, model_root: &Path, pooling: &Option, -) -> Result { +) -> Result<(text_embeddings_backend::ModelType, Option)> { for arch in &config.architectures { if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") { - return Ok(text_embeddings_backend::ModelType::Embedding( + return Ok((text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, - )); + ), Some(text_embeddings_backend::Pool::Splade))); } else if arch.ends_with("Classification") { if pooling.is_some() { tracing::warn!( "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg." 
); } - return Ok(text_embeddings_backend::ModelType::Classifier); + return Ok((text_embeddings_backend::ModelType::Classifier, None)); } } @@ -353,7 +353,7 @@ fn get_backend_model_type( } } }; - Ok(text_embeddings_backend::ModelType::Embedding(pool)) + Ok((text_embeddings_backend::ModelType::Embedding(pool.clone()), Some(pool))) } #[derive(Debug, Deserialize)] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/assets/default_bert.pt b/tests/assets/default_bert.pt new file mode 100644 index 00000000..e69de29b diff --git a/tests/assets/flash_bert.pt b/tests/assets/flash_bert.pt new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..6d8ed997 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,113 @@ +import pytest +import asyncio +import contextlib +import random +import os +import tempfile +import subprocess +import shutil +import sys +from typing import Optional +from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError +import requests +import time +from requests.exceptions import ConnectionError as RequestsConnectionError + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + +class ProcessLauncherHandle: + def __init__(self, process, port: int): + self.port = port + self.process = process + + def _inner_health(self) -> bool: + return self.process.poll() is None + + def health(self, timeout: int = 60): + assert timeout > 0 + for _ in range(timeout): + if not self._inner_health(): + raise RuntimeError("Launcher crashed") + + try: + url = f"http://0.0.0.0:{self.port}/health" + headers = {"Content-Type": "application/json"} + + response = requests.post(url, headers=headers) + return + except (ClientConnectorError, ClientOSError, ServerDisconnectedError, RequestsConnectionError) as e: + print("Connecting") + time.sleep(1) + raise RuntimeError("Health check failed") + +@pytest.fixture(scope="module") +def launcher(event_loop): + @contextlib.contextmanager + def local_launcher( + model_id: str, + trust_remote_code: bool = False, + use_flash_attention: bool = True, + dtype: Optional[str] = None, + revision: Optional[str] = None, + pooling: Optional[str] = None, + ): + port = random.randint(8000, 10_000) + shard_uds_path = ( + f"/tmp/tei-tests-{model_id.split('/')[-1]}-server" + ) + + args = [ + "text-embeddings-router", + "--model-id", + model_id, + "--port", + str(port), + "--uds-path", + shard_uds_path, + ] + + env = os.environ + + if dtype is not None: + args.append("--dtype") + args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) + if trust_remote_code: + args.append("--trust-remote-code") + if pooling: + args.append("--pooling") + args.append(str(max_input_length)) + + env["LOG_LEVEL"] = "debug" + + if not use_flash_attention: + env["USE_FLASH_ATTENTION"] = "false" + + with tempfile.TemporaryFile("w+") as tmp: + # We'll output stdout/stderr to a temporary file. Using a pipe + # cause the process to block until stdout is read. 
+ print("call subprocess.Popen, with args", args) + with subprocess.Popen( + args, + stdout=tmp, + stderr=subprocess.STDOUT, + env=env, + ) as process: + yield ProcessLauncherHandle(process, port) + + process.terminate() + process.wait(60) + + tmp.seek(0) + shutil.copyfileobj(tmp, sys.stderr) + + if not use_flash_attention: + del env["USE_FLASH_ATTENTION"] + + return local_launcher \ No newline at end of file diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..2f4c80e3 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/tests/test_default_model.py b/tests/test_default_model.py new file mode 100644 index 00000000..f8ab25fa --- /dev/null +++ b/tests/test_default_model.py @@ -0,0 +1,28 @@ +import pytest +import requests +import json +import torch + +@pytest.fixture(scope="module") +def default_model_handle(launcher): + with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=False) as handle: + yield handle + +@pytest.fixture(scope="module") +async def default_model(default_model_handle): + default_model_handle.health(300) + return default_model_handle + +@pytest.mark.asyncio +@pytest.mark.private +async def test_single_query(default_model): + url = f"http://0.0.0.0:{default_model.port}/embed" + data = {"inputs": "What is Deep Learning?"} + headers = {"Content-Type": "application/json"} + + response = requests.post(url, json=data, headers=headers) + + embedding = torch.Tensor(json.loads(response.text)) + # reference_embedding = torch.load("assets/default_model.pt") + + # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py new file mode 100644 index 00000000..38df22e3 --- /dev/null +++ b/tests/test_flash_bert.py @@ -0,0 +1,28 @@ +import pytest +import requests +import json +import torch + +@pytest.fixture(scope="module") +def default_model_handle(launcher): + with launcher("sentence-transformers/all-MiniLM-L6-v2", use_flash_attention=True) as handle: + yield handle + +@pytest.fixture(scope="module") +async def default_model(default_model_handle): + default_model_handle.health(300) + return default_model_handle + +@pytest.mark.asyncio +@pytest.mark.private +async def test_single_query(default_model): + url = f"http://0.0.0.0:{default_model.port}/embed" + data = {"inputs": "What is Deep Learning?"} + headers = {"Content-Type": "application/json"} + + response = requests.post(url, json=data, headers=headers) + + embedding = torch.Tensor(json.loads(response.text)) + # reference_embedding = torch.load("assets/default_model.pt") + + # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file From 35cc5b8c538c9be536ec451d4d379b086d845913 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:02:54 +0000 Subject: [PATCH 07/17] tests instructions --- tests/README.md | 11 +++++++++++ tests/requirements.txt | 3 +++ 2 files changed, 14 insertions(+) create mode 100644 tests/README.md create mode 100644 tests/requirements.txt diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..e5492ef9 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,11 @@ +## Testing + +To run the tests, install from within docker with `--entrypoint "/bin/bash"` the requirements +``` +pip install -r requirements.txt +``` + +and mounting a volume for the tests, they can be run from within the container with +``` +pytest tests/ -s -vvvvv +``` 
\ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..b1ee0f58 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,3 @@ +pytest +pytest-asyncio +aiohttp \ No newline at end of file From 309d25560bb41acf9d3c960cd21012bf2009cfea Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:10:19 +0000 Subject: [PATCH 08/17] add rocm image builder --- .github/workflows/build_rocm.yaml | 134 ++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 .github/workflows/build_rocm.yaml diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml new file mode 100644 index 00000000..8a9fde49 --- /dev/null +++ b/.github/workflows/build_rocm.yaml @@ -0,0 +1,134 @@ + name: Build and push AMD ROCm docker image to registry + + on: + workflow_dispatch: + push: + branches: + - 'main' + tags: + - 'v*' + pull_request: + paths: + - ".github/workflows/build.yaml" +# - "integration-tests/**" + - "backends/**" + - "core/**" + - "router/**" + - "Cargo.lock" + - "rust-toolchain.toml" + - "Dockerfile" + branches: + - 'main' + + jobs: + build-and-push-image: + concurrency: + group: ${{ github.workflow }}-${{ github.job }}-rocm-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci] + permissions: + contents: write + packages: write + # This is used to complete the identity challenge + # with sigstore/fulcio when running outside of PRs. + id-token: write + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + - name: Configure sccache + uses: actions/github-script@v6 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Tailscale + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} + password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} + registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker + id: meta-rocm + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=rocm-{{version}} + type=semver,pattern=rocm-{{major}}.{{minor}} + type=raw,value=rocm-latest + type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }} + + - name: Build and push Docker image + id: build-and-push-rocm + uses: docker/build-push-action@v4 + with: + context: . 
+ file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-rocm.outputs.tags }} + labels: ${{ steps.meta-rocm.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max + cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max + + - name: Extract metadata (tags, labels) for Docker + id: meta-rocm-grpc + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/text-embeddings-inference + ghcr.io/huggingface/text-embeddings-inference + flavor: | + latest=false + tags: | + type=semver,pattern=rocm-{{version}}-grpc + type=semver,pattern=rocm-{{major}}.{{minor}}-grpc + type=raw,value=rocm-latest-grpc + type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc + + - name: Build and push Docker image + id: build-and-push-rocm-grpc + uses: docker/build-push-action@v4 + with: + context: . + target: grpc + file: Dockerfile-cuda + push: ${{ github.event_name != 'pull_request' }} + platforms: 'linux/amd64' + build-args: | + SCCACHE_GHA_ENABLED=on + ACTIONS_CACHE_URL=${{ env.ACTIONS_CACHE_URL }} + ACTIONS_RUNTIME_TOKEN=${{ env.ACTIONS_RUNTIME_TOKEN }} + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }} + tags: ${{ steps.meta-rocm-grpc.outputs.tags }} + labels: ${{ steps.meta-rocm-grpc.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max From ae3da109356c830fd5b66bd31a88487ffc002521 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 12:48:26 +0000 Subject: [PATCH 09/17] add reference tensors --- tests/README.md | 23 +++++++++++ tests/assets/default_bert.pt | 0 tests/assets/flash_bert.pt | 0 ...ence-transformers-all-MiniLM-L6-v2_inp1.pt | Bin 0 -> 3024 bytes ...sformers-all-MiniLM-L6-v2_inp1_no_flash.pt | Bin 0 -> 3069 bytes ...ence-transformers-all-MiniLM-L6-v2_inp3.pt | Bin 0 -> 6096 bytes ...sformers-all-MiniLM-L6-v2_inp3_no_flash.pt | Bin 0 -> 6141 bytes tests/collect.py | 37 ++++++++++++++++++ tests/test_default_model.py | 4 +- tests/test_flash_bert.py | 4 +- 10 files changed, 64 insertions(+), 4 deletions(-) delete mode 100644 tests/assets/default_bert.pt delete mode 100644 tests/assets/flash_bert.pt create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt create mode 100644 tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt create mode 100644 tests/collect.py diff --git a/tests/README.md b/tests/README.md index e5492ef9..c4ff5d0b 100644 --- a/tests/README.md +++ b/tests/README.md @@ -8,4 +8,27 @@ pip install -r requirements.txt and mounting a volume for the tests, they can be run from within the container with ``` pytest tests/ -s -vvvvv +``` + +## Reference outputs + +For example, collecting the reference on an RTX 4090 on Candle backend: +``` +docker run --rm -it --gpus all --net host --entrypoint "/bin/bash" -v $(pwd):/tei 
ghcr.io/huggingface/text-embeddings-inference:89-1.2.3 +``` +and +``` +text-embeddings-router --model-id sentence-transformers/all-MiniLM-L6-v2 +``` + +and then +``` +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 --flash +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 --flash +``` + +Restart server with `USE_FLASH_ATTENTION=0`, and +``` +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 +python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 ``` \ No newline at end of file diff --git a/tests/assets/default_bert.pt b/tests/assets/default_bert.pt deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/assets/flash_bert.pt b/tests/assets/flash_bert.pt deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt new file mode 100644 index 0000000000000000000000000000000000000000..aaf95a92819bf701ce3a6e37c1c3870523216604 GIT binary patch literal 3024 zcmbVO2~<-_77aTD_XPzT&~9Y_34}e8dWac`N>D-Ck%1PBAz(BT5(46SupOioS3tA} z+=mf2Y-yE|Uk}9v7j#_OwUKVo)*b`~msSt1Qy**(>Tu35mGe`T{8x2vz4xl#jZm1G z3k25If=`;Cz*CT_PByBORcevZpiE9p)EQ>04XGlfRx28>N!Emq7lj9i(tYAJ$tm7m z2}+~VGi8=mE^`;e{Ld}ZRAtl|R5Lu|(~KHzsxaQ5o}Q-BCd4z`RGlF{-AAYwxVlDL zF@CAcQkkUIo3u|6%C$PBF{=Gfq24swgxQGoX3=J>SEZ!s&Bte&81)vB3W36e4VUUI z*|&oAR?Ac>omQ(>88y1(RN+KJg4&=?7_Ctmg?ekZNqQTjUN}i_>lvZ2vd&sz7P*lv z$UhOCI?G=W(j{k2pKDa>cv!iNbirGnx>c(&8nx;+qHx#Vl9A+}&WZX69J)l5kE)kg zF0<=|WnbN9T{{G(r3-mhiQxGJ&ObMME_TmD|Mi@P1mXJiZeR&yo9@^6q0$9 z9I?FUG)-La6t|~|xv#S_qzC;7oYq)X##t$Gj`0$Cksl3mT@G{SD-;T*9qh|ThhF$x>;Mq#Ahn*q~1C~NzUMLy^?||Lyd>nf{o}`B;$n!7m(sARR zQQc5OW4rCHaPv7!D#;6QDF}qa#ojchZX5R9Ud-E(i$!x1L}VcYNMQ3JJhO2n9@yFkzp7t}kB@G_mG^BAUn_USf0s_672104 zX7A@X+dT&Lg%hFAKV_W0&YJweO$iZ)&vL1u{pq<4e+uSunG9w0f8h!5>*>9k&9yxL3jr7;|GNahblK-Ztq^gFQom`ARyV-ymtj083iv_c?q$qX%}Y3W5c( zmLzew1^p5=*r@-RKKSV%jkBsJQSL{P%}4jd2$USS&-L;VW0gZS?SCc%H7^~(qn91r znsA?-eDD)#P0k{ZollUbuUCU@(-atEx)39eeGhvlMN+m0!jB)8m&4on-Gf^_!fANz zB56oLB(d96O9u2SM6+-gnpEjYB4T!sfSB&cavF0NUV(t!3-QTsF7VqK8(ceXLXe^N zYMklyH(-3xX~-9-m2`&<2U~;UN)Su-UTurZFPB6#ykS4({hym-g`*Q^LP~WLuBoqr zsG84!#VP1mO9?(dF%lqY9nQBuPgp#NXmk-55pIOAky^syPXs;7sbhEuj?Ex^zM=QQ z-gw*TEou676Z2&NE_SeBIlqUVkAFfd8*MN*C!16|scGCe2RKr8n6NX1d6UX}~z%|B3BK&h{g_rF)qFibN>yI7S~)?K5^twYF>r~=sJbpy}Ob^*qhF7usA_Epruu&OVmiOXF{ z6}Js`kNa0!e1qKoPTkhdmNpM6hL%woKyu6e!SLaigeg*Xo&du{kTzp6+4lA!N3S%K z8_v7%de|6P^0FLOhp5S+Q)vZjs6UK4Smf$ zw2NhSF(^4^AABUr4g}xzr3We(;J7_9wvXX(edhy?-QUFf@_e$fsFbq2fjvidK;){u zz?J@$^Lw00UI_RuH_WfY}2qs{=vW0ia^)W!l2+2MGxZFj6|DHP$3QJ3mck4QtiQXevyc_-AKjWn= hYZ*4}J)6f1U_r8N`)~ergu;TCdIAgf;qBXP{{yb>Qq2GW literal 0 HcmV?d00001 diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt new file mode 100644 index 0000000000000000000000000000000000000000..d986e33279c76e2005fdec89b65e617f1787e992 GIT binary patch literal 3069 zcmbtW2~<;88h&9XZi5Rh6?C*^5rSdIy!(TuEK?(hxFL!$U?9;n5E5{OVil_*id)Cl zQqgfoae=Dk-QTtDsL(oAMQbY#l*(w28%mwp`@r#tj;H4^_ne!TI`lA$r0Q_?a%)Ml7v8iPR=tvBjNM9W45$+G;^dSkk; z+L)$JF=))X!BaIBjZgY?L!|PrT*9A?T&rYDT1K+YN1bWW8_ZI5hBhfvZu(~M^6$c(Al4DHl#eX>Pra`22ZIa*B8I8z6o zSe3m)-csvPMeK#*jD2gcLb>4f1uJQmMVlt#M=JgQJX&u?)1bFl4B8f=P}|#%kre0K ziSnFN`-qBJ9W0Mjc52HqJ7+zrbq?6?gJPmv+{jOxokQo2E2qX)ztMHUeW>7kg*qPV zhmK8Ec(ZaHS?lV8(fv=;%oC6B_T@oB!@Nvh+K&Jh*&mYlYH+opnsl5S59KpQquZWy 
z^y4)a_-Kn10=gd|2TpsUZ^9?!wla|3vFwKEu}^96XQxP6T^?MY8-atvZh>sqYJ4(F zO*G3?WY@l0dSB5GXY{C~hnoLUVVr!Hd^M~Y&Tb5bhhDz)+}VwIdhcd3DY};EU2f63 z-`CLRVI#4~V<#*+dF-Ar{8VK?iT?s!hJbNWG4 zbXSo#zvz^a_?}z=-Z51;DLNUFRJpLKYn||0q89-{-!gVAxj0hJp?>P|CMFW+q&ImbOdrJ_82i7Nvu;V<9Sh1H}=v7URmQSL`ZOVjyO)I8x=PMx2su6P) z9F0q|!Kx1Vv~;rz_WrSwI_5_r6B_rp-l-xdDiSA$X5yyh-}AGoy`lN{6jTnYp~lI> zNyNz{tgTH4qvvPzl#85%2dtq_4ZTA@s^MYD{E@^UDaEEapFn}-CkU!4#-&4TsQ8YU za%8!U;%`JW3b$vZP_zFw(tTPPLHl1w{)jAL&eOl+=tu97hhBSWg3yEZoEQP`9{!rw zOz6YcPqL%s?+<|5S)Fldi2^#5+Y&aZVWS@3ss4c`?cYz;DL<0tz#~`@y%)2~V)01- zdZ8dpj@Rd&pmn;L_+07=?7a{lP)AsdNyN%L!g592yA;BshvT8qE*pz?mBEExqv(_H zQgSW7jOpqC%qi#;+E}fk?-b4BpSi^m76&o+_~@`3oxjM3RO_~qK^~ouX&P@WyaYFc zaxmB09d6V(;?oVo6?x&y@q^G!plY_kzVd;{*2063CdJhz!~%!!1?Dq)Hb_S6C+?$S z{CyMb@o%bW;8W6wq34dmx6k^*##DnM_WCy5UKIu`-ne<-d2(Sv4PiMDL_dZ5eMp{% zT@J$PIR!ipGMEZ@D8hH?*R&kDyD@7Z6}ioOVVuwYMbT-m)d?$>KEbHW``7{9B6d5|p4as&^9H_jOT zAAxhsQfT=@g8pYY_%~jnrA{88TWsKax~=9#PCLsxAk!*knq@jVOCFE@g6s*N4pAQ; zr&a&V!=ZVz>GZ3|sJI8S?kXJZwgC^E-VSkV7xNACi%81aYrx_{M4hd5Edu6CeE0E6 zq~?cwVrC3IF9cvt)Ja<4Fdk~gT*Jy*%r%Vwr)e4!G z$&M#Gkm+{F_I1l}i?jW{({-0=rF$4LDL;jdbFbjwFWrE7n-(P|5|0fxM6K{_AB3GX zpiRH0Xy%qPEo`A>u{z$ZeiN**o(Y->pC7z5U<%w?IEH6u0@(JDLE(gX#O2tp!tu>_ ziL&DkWL|(Rs)OL=rX}Zp)uMOBHM%dZ5ZKvA2PpQ@g&Vy1^)dILcTYEv`}IY3HVMjM z6R|ie0a-pmOZq!QL(gth)a_))aFX7~3z(M#mTO|YESt<f{ry` z2t}T`@~ zR7$u6d0=o*NMJ}vV6cBk;E*9fe!(GrayESBfg!SY)6I@YVw;M-tbZZB4ABl0@X zs|p}(fMQv1+pbN$(B11uuVzuz2I<@OBK>cD#j5P=yk7vgTFYKS*}nMp-|N5@A#vPH h$JQI{Ik8*j)^%w8Ew;t6DqFEHWCpP1#PzLf{{Z-0Ya{>w literal 0 HcmV?d00001 diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3.pt new file mode 100644 index 0000000000000000000000000000000000000000..bea6dca1e82781e1446019cb8d5c37a1876e9a7b GIT binary patch literal 6096 zcmbVQ2{=~U*EeJ?rH~=fASFX7!r4orq@cm&3o3Wh`R(Q)GEXM)-yZVQs9)Pq6jhv_~Efm|>21%chR{ z#YMp3_p(r#afY0@H3}W#HWBak@pN`f0bCsZnGUkPNM3cu!kfe?c--JI#Fy>B^sDZq zkM<05wd65XpQ?##n`>#~hP;Xd>C5EC+%Itb@;;L=z%h`lEO1 zHEy0p2yv`WfsJa{vCPdI4Ch5bi_%lRvqhH#hG!6&M}heLL_4v#@S6Kva|te9Z>P@| z)<9L-4v2gm$<8tr7L-a`IG;(xd&%oa?_0{`nbc{t+ntOpiv96mcQUSP%tm&m@@XG5 zFlEwAdZ4F?pDf*lT|*XNM(8XkKR$)uzC)VS#CifdFMs)h5|zC_iIn&)X7K{Z=yDxk zTrk^f0yl7f`~{M?`wI1ae2TCf(cpn@bn?Ped}&2CO>3xtt#8{gpjL#cW%a^`H#XCe z!5aASN*(P#zy_JnL`wWx%VVxh!K(HM#>rhS(OC~}cPvJ;F%2};)Q-4+_Cni_%V67^ z-Sl$s7!vBalRkG+wOGHAgVk*_NvZq^A*b-oyBSzaJUqP@9h9L5;yK8;nanYc_^#!C zLjFi@XCYZ2_kv^wt>l#@cBArqZE|bhDcTjGNJlN50vDo6ImQdur6onH+J?aSjq;e* zU=C}?OA^aU32I&JkD8Ak(jfIx8nd#AEL&WKz1N;XB{v7W-O)75DhP|?St z_>3J@2wcZqcbP*{b~ca!@)-hFT4iBKWQ}r3#oRuOOIg&|brU+Dti?(Fbik)o2E8X* znV-<#ij$=dLN8kh92PYUnJsk7eKTJuS&1bB?(rSFvju*rey>6*@XzW7Z+7{^Bz%d# zeLV}ibT#1fz>VfX7xORye6D=dXXTm4!3h6M=a6#j2zx@)S;^~0_o$R%cRvklRVT3fS{0S z>UumD8#n$!-^ZMztUjoG&-JP-b= zhZsKX?M2EZjH;m~Z6IG9(&>KA7pN3ac7L&S4U(9=XpLTNFz{V+#4m@5+ues`S8WUf`wkadv zMD!pude_6}y2*D3Q+=nP#jDhe^Nnlu6N`0C=aYln3x?lDJ6K@N!J5xcLbhbe&tJjyL zu_C5Glt66?*47RH>(hJ5fSerqrMEcFb3X|pYp2rD##Ma2c09%}zD#>OOsJOBIjRw_ z4$AVIXj1L~$bxyuxC77piE#JZGT0>HcL-JwfCCYM0 zYu?zRsg=5^*(iT%=^9BTowtMfx>{`Pa>K}iTSR-CDLC9VL(lv(!0gIr8?G1l;P(Wl z(wz8Ev=krDnPxOqq;lf$ICC+0O|V7lOAo2ZZc$*`B8FoOxEH<5L9f~t;ulZnKk1Kv z!wS#IzAhPbYuQ0rUb*ibmBe8GA^7^v0o`2!Xxnxxc5mmNT5N@prRmUleh$4hSJvX! 
zm!)R6H>u&tQxB;9CE3bGy`dP?TNL~iQei@nJeGEA;(q1TV8uD%NT(K>V|WavAIPTx z(@IFL{%1NPCXUL7E~9#%63zXWx>T_I31`Gl`iJ?Lf@oM3y@9g$WVqN#U^;=hn-7rn zR)e@vPXdIT)A`yBoaOW~s;a*drYQ}vF#701W3#i!PK(Ec)d0B_AcN^TeJhzZF-_Ef z6b!qCH=c(6`1{LzF)X8~WeUM}-4R4)Yq6jgcu=7WwS~>(Zkr)6TVOm^_8*Xr>1Nl# z{?cX?cy)SJMp!M8IXd#NW9V!Yux|J$QF*MrkY?4T(bv18fbP#l0ndt+MIfc83AYN9 zF(zp}vU)&On|z*Ijp!S)R5%m0zT^f4J7+-2udbvq_F2XHuLJ3TFRiGt#}%Un=97bd zGll)+LB$c`{UsHRp6sG^mTlB0J`;Y~Vg!{lCIPE;d|W>Yx2E6cueyGvW2420e%uZ| z_Spd%-Sd%S9N}-_t4Yn9Uy)l4z;JOX*6;DRJVfRjc0K=v)mIIulx35HnF{damug;n z@LA?p6rgN?~#AqTKNK?}0KJf}UXg_t~iBZPJk z{?G<@@~w0bve&5KU=Nq1ZNa%q8b3DHa~G^Mkm(InZ;qu5JO8ZcDFi5nQPw;7zK4s5 zN9#MD=>Vn%NTEnT95GCI$5~h|LG~V=75Vc13IpgwgIxZ;{dxRNt%S_A*2E7pr{RL< zhpF{x4b&by30a*8xyRc1nZ&2LnN&HY5LW-Ncf18Kev$D*WNc4@z>Uz)Sic|x;$PE? zr{B@YCn})-Xg*Y%XQri^?gynSoWa z7Se-<=Yit6XjJW7#?C$nd*#FcuQ?h+R&{V1!w=wO&WW^?bn^uV#y~R2!`7+uVae$v z)c&eMZVo#}?PA^WUZXfvY+H-Jj@&NT!-^@YB*Q8kzgn?=2H`c&-;>ojU zjPJfyo;a)opI&od-I<>1b=FAN~TeKecC#@;poMsEm9+ycC0(-SodUjDBR+b&0HuDOo z5UYExDb~#}A!4!5iC1J_i~X6hIKk*O&p7?j|K(iWAo6)5X$UGI-)!~JzW)reXwvSA zjzSkqX{tBRZ|cLXKlGWOb+m>^>INap6E^G40lB3|!SO)}<@a3{FvF<=N{_zPgd^Tz zDC8q!PBb6rHSmYuBfjIPppQ=b?J!GE8<|Gvf;~6r(U2a(xP!iZPm(E@#ev!L3_sE0 z?~|@w>fo2N3VaegK=y$JZkXOfDh6a>Thuf1VAB?&BjH0{RY~){om%%n#nCMZKtF5xuWFJE1+brPg&0J=8ynr>${Wm z(i&<$)JYMNPI%yB6E`e&62pr_djr##MP2PYuIf5Rd{wqUsDd99&z(fFy5(r`+_Ri* z(*k69s<3UkPaNISs82!@Eos)L4KHt!1Q}Nl?|+Nx>dV5TR5hIAa0$~*^EkzGqWEFG z_>cZiUP%Fj=|j4SA7vbnsPVa+wc%w5SrtrJUBR#yafIz7ng+pi<@FjY^|->tg$=Jv zOYwrsXJw&Z<{(TmoMa~BaD&xw64+c%z>67&=-vlLC(?txwDwbpth2(|1#Ll_ zOkZNS$eb1oPoeEe;p9QIG;~RKkhU%3iAioXrk6_-(NDk7n)mL+{eC}W^+O*>p1@4E zWXiZAw;`5l5uQ)$Z%turZkQWK52*UX6RS32xjl)D z`0R{PMHlGjoK|Q)T8On4FX$m-aiV)HADaeWfRNRO_-6h-bQaH`I$IK%){@A|8?oFB z#bjWg!_iJRny{HrCA&aSoajOKoL8hBV?WXLx#ln}XetW5j`_o8y#dhbvzcRcf{PE7 z!{AeWsof?mQaPcT?b$(ncO6HD1-_QA1MhF)IITkm8UMV-@L}k@SrmlHlcVU1DI)#Z zmqB*@`{Zeyh~`gCo>s6?BOA}zSgHTZ=;_BWY@DospnjjP{?Yb<*(-pts@ z)MUa~lW}8A1(f}ETsbh3ArvkLJ5d4bpOe?Wjv2=b5B|3PKXcfBM)>P|(D-lO_1WO` zXCnMLy!>Z~zY4+VZy|2~WQgC#j{o%k*EEm&n|}wzpF6@I;`(QRzh=nlZvk%nWPtx& zI}Ub|l6rsm9W1W@X9S5qUj5JgqcF;5Yn9UPceYO$fQtkV{QbS~*uhRh81;%suphzx HciaB}S|>Pe literal 0 HcmV?d00001 diff --git a/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt b/tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp3_no_flash.pt new file mode 100644 index 0000000000000000000000000000000000000000..7bf5187911d6098a4ae44bd9cea3f5b5d8d0ff2b GIT binary patch literal 6141 zcmbtY3s_ER*M2*n6ggK2Nuole!+uthgmNfKN;;@idUbrG59uTnQAs0_)8vqdln(n@ zX;jlV53x!f+c>zK7NLhaPQ!V)gj?){lX&*y~Sd~g@M6=3l7S{eFwGtvJM6vh?>qDu+ogmtuRSgvl(IVP~Jv$S~Q3 zv9gjdIY)buz0AT`sjXr1?590DDKh+CsO}9ph8tAPqDRD{VU2st3$Y8ts?HxI6_Jv|IB44Y^LU15mj&d zNV`2WspEkA)LmsX*8Ow~-|osL(b6OF!{b`&T-Sqhv<&&$A;%RxqKDf)U{r_$yUh&!ze6do0MM=qg`zW zDU5i-F)Tn6I;hC08fsEA;F8@2mWwSkol~@&xNZl&EsZ80>V}deQ3>u#Pr*@{19AMl z6tpZljCrDQWkmxFNpj~+DVoWzZDZwlP7{WZi^=?0bkbe`x~-=-z6R?={dBHoQVLhTC6 zA;;hqx)zJ*!|k%@bn&?cSRQW>#TMXBY#1Q{`TxYQX?0B*o z8y`2(E!%_|_+aL%PzdPQM>n`mCcQa(>6UN}%e%S+Y|IytVuiC9SiKF-Z+ZwqZXwrY zsql;mJW^Iuu0i<~uc7Zx-}js#?RPJbpYmRi6F%$tsV(1Na-%MJ9(kSyuNXo@t>%DM zdl9#WAI4R_m#1m-$AV&v8miY@Kv0G}nKnwE2EGnN#uxpxvzXrev5gdZRN@$?^Z06- z16Ef(<=v&Hpr*?AblUwWw00T=Rh2`)>HZU<*VsxNBNNHv-fEKZb~h}xS`G)7M&s_# z3#`@-bd&yBvMl`qh)FPQKa-OUqPnimX1R2|+YKbd}h4XA^c0 zYkL!M;jRr#*HtukTM1nl+6asnY^*y4e&dq4lC}fH{&3S5y_e^3jFW@Fe2L*7 z)3G+Whhw^pyv zoC#!9%280A|N2XP9Y5w$Fyq`00d_kSf63b5prb6J*QuJucqGMZA@v`o4tUaH_rCdM3#b+uJWtHdLjF=?+ ztQZe_&EzrhmwS|7y$U7`xqwrrTqcXoY=hjS3q;-ZG~QS-1AiE!h&S`)xUjD6#NP5C 
zg{%x1pwIv+1!|=0O+18j?}QsGF4CSRZ?y4VfWyZ4z;ET2koLTsJZjku4QcgoZqbAC zmdhgYzF#`-tsMm2{c}k5g(K9xN*c4n&VdBy(O1VV@q^8`qoMItnzF)#GAuOxiWcmz z-bSC+tU&njh#nLU~27$Cq5&dxA8Jl-608J??_|Ovyo_$3m-*^Emk0@dBE6AiJ zYcZ#<7=s5Hz-LA2iV9yhFzKI-ZFwv3NwWePi+>{}n-#Fmtq|gVoK3^uZh_ePxiHv8 z2g}RMvG1Z0w_m%3Ub=XOSSI{xp{2coYD!h&NUWt8yiWd`f(!g)1#(| zZmD(PBhE&~O?h#Q3T6C4ud*^c+(B`;vzA5f@b5tAKhesH;M1ZC1>GqyAU~cQ)x1Kp zoU$n60k39B%nTd@sZ9SU%J>3@sB0Mi#skyj?-3Rof{&V`bL1sp_)9GkW9YA{LwPgR zJ+v||44))T=U9Eq%c5n#A#*hhlbMH1FElYj3QqfPCvw&XT+AE`fp3`OHILtJH6GZ0 za;{4W#qTocOy361duIi?T$Br;mmIL^Y!H3NTVq&D2bZX`8#eCGf`Rf*^nru2<(jHC z^9Co)FVAk+WrL-OqSt`sz)?!43>jOEphu9W$Tuvmn7fH1A2MZf+P1xGGos;eP6IPxxz zB;jA)zq5;u!SZ%}YPXUIe3R`rFObWzYhl5SiBKE;kXV*XhOMgMuy(afh1B>gJU#6e zbao}9z-zZ*8PN_oP71o&J7$C<-aZ`%tDEI2=G`o$8cyF(+l*M4ReS9su|4B`)EnPOKOpQ0P5U~A$RyJ=uz8`EFYY|shL;P`zX|# zXf55t9}fPOzJJ%v9d}xf=fxX{1b@J_#{lANE^^a?x>5EHqDgEDta-Q*HKeNng)u2aYxT=IYckV{!C;$8;IcH zoB4^9Q`rHXEo-r*O-s-OH+HcM*4~_si@$kEuDM$jjx}EmttmDTnWGP82S!6x%nRBv zxDc<7-U19CugboQiu6X}T?ydLsD*-#A*@Rg4~@9Z*;|gp?=qDEOQuqbe!Kau<|lA+ zX#`CgvyCUE7s!X6PTuCiAqetUKmlv2ZVzV`5P(Tz6oAF$oofeDrbC{aSA%?MF^SRB z#(?AVvHinw%JhYdPaF|u1gzG0OniWNoVN?jM+H2#`|EDmLw{rxpbB~9qu9L8Vd(90Rd<^@`xsWaTz5Ke8 zDNyXB2FqMrA^p@&^qrwW`kCd^bs3(Rl_E=S{ig>+nQzYK2YYLNQH5h(0Cc1{=h zx9w!(rbeP0{(`Fu`AE<}7T!y`=;H6TfPG;pgltOYkGp-Q7ufrz`%X5@f4CF0cZGt^ z=2v7**Cp^BG>_`;wSYX^512Vf0pvEySynvqB2HTkF>SIMaMBCW*|V91Uo%0)5NYVm zQwPNb=Xu+ocY~0h%r^F3P>U;%P0#_Cb^18^OeIJz72+z(Y&e>Bm;1b56}~TD2BG&$ z!P#>KShvoGJ)&}cT~i!ZHnoLI z44Wngfwk0Lmh*6nyJ>^LF@pUMo9ZFd9+@XVe!Z+E(>urXPKup&LD`4ZuxrFDqFbs$ zty`HlU6)xqfa%B=!{YIv^+~~Dy9xIm}fe+#^VD%QBNz_N^VySc>FX)@`j@N3A(2qO5 z6LQBZH;xJ3CF9e~>HWks%Ct$!D-?lwoIEt1PDYhgp^Tg&sSDXi4HkKk-DLyNG4nj* zyuU z_y{m z-_y`@Pa&(l5KH7=(n! z)`Rt3juxwlq5ZQ5T|IFqW&E(dZUITt=b~OpDeRY#L7VpjAwxTvV?IGCaTzd=(>IT_ z$+`Jel0z3DJG^f%uLM8OiZVlnV6f;Fg7(ap2Ci)?1G80vAMB{nW_1d>81io z-JcazCNPA;aY!2U6L{qvadOJfx-{&D=j=CFSc_2)sM+bK7{#-@-zeKw8wUPdBeL2|6%NzVV zz$h8RKSIg<`_uoPb-oBmB#IjH`vrENFfA8l5BvRhVJz-oFDJ}}nE~WP!v5d){tw(5 BMB)Gd literal 0 HcmV?d00001 diff --git a/tests/collect.py b/tests/collect.py new file mode 100644 index 00000000..313c0871 --- /dev/null +++ b/tests/collect.py @@ -0,0 +1,37 @@ + +import requests +import torch +import argparse +import json +import os + +parser = argparse.ArgumentParser(description='Assets collection') +parser.add_argument('--model-id', help='Model id', required=True) +parser.add_argument('--n_inp', help='Number of inputs', required=True, type=int) +parser.add_argument('--flash', action='store_true') + +args = parser.parse_args() + +url = f"http://0.0.0.0:80/embed" + +INPUTS = [ + "What is Deep Learning?", + "Today I am in Paris and I would like to", + "Paris weather is", + "Great job" +] + +data = {"inputs": INPUTS[:args.n_inp]} +headers = {"Content-Type": "application/json"} + +response = requests.post(url, json=data, headers=headers) + +embedding = torch.Tensor(json.loads(response.text)) + +postfix = "" +if not args.flash: + postfix = "_no_flash" + +save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt" +print(f"Saving embedding of shape {embedding.shape} to {save_path}") +torch.save(embedding, save_path) \ No newline at end of file diff --git a/tests/test_default_model.py b/tests/test_default_model.py index f8ab25fa..595fe6bf 100644 --- a/tests/test_default_model.py +++ b/tests/test_default_model.py @@ -23,6 +23,6 @@ async def test_single_query(default_model): response = requests.post(url, json=data, headers=headers) embedding = torch.Tensor(json.loads(response.text)) - # 
reference_embedding = torch.load("assets/default_model.pt") + reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt") - # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py index 38df22e3..3c3fde1c 100644 --- a/tests/test_flash_bert.py +++ b/tests/test_flash_bert.py @@ -23,6 +23,6 @@ async def test_single_query(default_model): response = requests.post(url, json=data, headers=headers) embedding = torch.Tensor(json.loads(response.text)) - # reference_embedding = torch.load("assets/default_model.pt") + reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt") - # assert torch.allclose(embedding, reference_embedding) \ No newline at end of file + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file From 7bba462cbed671002e44570abed049e744b99590 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 19 Jun 2024 15:14:51 +0000 Subject: [PATCH 10/17] update doc --- docs/source/en/_toctree.yml | 6 ++- docs/source/en/local_amd_gpu.md | 40 +++++++++++++++++++ .../en/{local_gpu.md => local_nvidia_gpu.md} | 4 +- 3 files changed, 46 insertions(+), 4 deletions(-) create mode 100644 docs/source/en/local_amd_gpu.md rename docs/source/en/{local_gpu.md => local_nvidia_gpu.md} (96%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 211d1ca5..b8fbb6f9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -11,8 +11,10 @@ title: Using TEI locally with CPU - local: local_metal title: Using TEI locally with Metal - - local: local_gpu - title: Using TEI locally with GPU + - local: local_nvidia_gpu + title: Using TEI locally with Nvidia GPU + - local: local_amd_gpu + title: Using TEI locally with AMD GPU - local: private_models title: Serving private and gated models # - local: tei_cli diff --git a/docs/source/en/local_amd_gpu.md b/docs/source/en/local_amd_gpu.md new file mode 100644 index 00000000..8dc8e2de --- /dev/null +++ b/docs/source/en/local_amd_gpu.md @@ -0,0 +1,40 @@ + + +# Using TEI locally with an AMD GPU + +Text-Embeddings-Inference supports the [AMD GPUs officially supporting ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html), including AMD Instinct MI210, MI250, MI300 and some of the AMD Radeon series GPUs. + +To leverage AMD GPUs, Text-Embeddings-Inference relies on its Python backend, and not on the [candle](https://github.com/huggingface/candle) backend that is used for CPU, Nvidia GPUs and Metal. The support in the python backend is more limited (Bert embeddings) but easily extendible. We welcome contributions to extend the supported models. + +## Usage through docker + +Using docker is the recommended approach. 
+ +```bash +docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --net host \ + --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 32g \ + ghcr.io/huggingface/text-embeddings-inference:rocm-1.2.4 \ + --model-id sentence-transformers/all-MiniLM-L6-v2 +``` + +and + +```bash +curl 127.0.0.1:80/embed \ + -X POST -d '{"inputs":"What is Deep Learning?"}' \ + -H 'Content-Type: application/json' +``` \ No newline at end of file diff --git a/docs/source/en/local_gpu.md b/docs/source/en/local_nvidia_gpu.md similarity index 96% rename from docs/source/en/local_gpu.md rename to docs/source/en/local_nvidia_gpu.md index 7b76300a..f2e71cfd 100644 --- a/docs/source/en/local_gpu.md +++ b/docs/source/en/local_nvidia_gpu.md @@ -14,9 +14,9 @@ rendered properly in your Markdown viewer. --> -# Using TEI locally with GPU +# Using TEI locally with Nvidia GPU -You can install `text-embeddings-inference` locally to run it on your own machine with a GPU. +You can install `text-embeddings-inference` locally to run it on your own machine with an Nvidia GPU. To make sure that your hardware is supported, check out the [Supported models and hardware](supported_models) page. ## Step 1: CUDA and NVIDIA drivers From d87b91de13ba6d7ed8a69263cf0e1715c09a1aa0 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 20 Jun 2024 08:17:25 +0000 Subject: [PATCH 11/17] add config-inline --- .github/workflows/build_80.yaml | 3 +++ Dockerfile | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_80.yaml b/.github/workflows/build_80.yaml index fcee1784..c032cb00 100644 --- a/.github/workflows/build_80.yaml +++ b/.github/workflows/build_80.yaml @@ -40,6 +40,9 @@ uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] - name: Configure sccache uses: actions/github-script@v6 with: diff --git a/Dockerfile b/Dockerfile index 625d165f..c84baa1b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ WORKDIR /usr/src ENV SCCACHE=0.5.4 ENV RUSTC_WRAPPER=/usr/local/bin/sccache -# Donwload and configure sccache +# Donwload & configure sccache RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \ chmod +x /usr/local/bin/sccache From 9941fcc8965858988043650df920d73a151d112a Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 21 Jun 2024 13:58:43 +0000 Subject: [PATCH 12/17] style --- .github/workflows/build_rocm.yaml | 6 +- backends/candle/src/lib.rs | 54 ++++---- backends/candle/src/models.rs | 6 +- backends/candle/src/models/flash_jina.rs | 2 +- backends/candle/src/models/flash_jina_code.rs | 31 +++-- backends/candle/src/models/jina.rs | 1 - backends/candle/src/models/jina_code.rs | 22 +++- .../layers/attention/rocm.py | 2 +- .../layers/layernorm.py | 4 +- .../text_embeddings_server/layers/pooling.py | 4 +- .../models/flash_bert.py | 2 +- backends/python/src/lib.rs | 118 ------------------ backends/src/lib.rs | 4 - docs/source/en/local_amd_gpu.md | 2 +- router/src/lib.rs | 26 ++-- tests/README.md | 2 +- tests/collect.py | 2 +- tests/conftest.py | 6 +- tests/requirements.txt | 2 +- tests/test_default_model.py | 2 +- tests/test_flash_bert.py | 2 +- 21 files changed, 103 insertions(+), 197 
deletions(-) diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml index 8a9fde49..5971ed78 100644 --- a/.github/workflows/build_rocm.yaml +++ b/.github/workflows/build_rocm.yaml @@ -79,7 +79,7 @@ type=semver,pattern=rocm-{{major}}.{{minor}} type=raw,value=rocm-latest type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }} - + - name: Build and push Docker image id: build-and-push-rocm uses: docker/build-push-action@v4 @@ -98,7 +98,7 @@ labels: ${{ steps.meta-rocm.outputs.labels }} cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/text-embeddings-inference:cache-rocm,mode=max - + - name: Extract metadata (tags, labels) for Docker id: meta-rocm-grpc uses: docker/metadata-action@v4.3.0 @@ -113,7 +113,7 @@ type=semver,pattern=rocm-{{major}}.{{minor}}-grpc type=raw,value=rocm-latest-grpc type=raw,value=rocm-sha-${{ env.GITHUB_SHA_SHORT }}-grpc - + - name: Build and push Docker image id: build-and-push-rocm-grpc uses: docker/build-push-action@v4 diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 27c5d843..4af5b704 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -11,17 +11,17 @@ use crate::compute_cap::{ compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap, }; use crate::models::{ - BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaConfig, JinaBertModel, JinaCodeConfig, JinaCodeBertModel, - Model, NomicBertModel, NomicConfig, + BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel, + JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig, }; #[cfg(feature = "cuda")] use crate::models::{ - FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, FlashNomicBertModel, + FlashBertModel, FlashDistilBertModel, FlashJinaBertModel, FlashJinaCodeBertModel, + FlashNomicBertModel, }; use anyhow::Context; use candle::{DType, Device}; use candle_nn::VarBuilder; -use models::BertConfig; use nohash_hasher::BuildNoHashHasher; use serde::Deserialize; use std::collections::HashMap; @@ -133,7 +133,9 @@ impl CandleBackend { } (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => { tracing::info!("Starting JinaCodeBertModel model on {:?}", device); - Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?)) + Ok(Box::new( + JinaCodeBertModel::load(vb, &config, model_type).s()?, + )) } ( Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config), @@ -171,8 +173,9 @@ impl CandleBackend { Ok(Box::new(BertModel::load(vb, &config, model_type).s()?)) } } - #[cfg(feature = "cuda")] - (Config::JinaBert(config), Device::Cuda(_)) => { + } + #[cfg(feature = "cuda")] + (Config::JinaBert(config), Device::Cuda(_)) => { if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) && dtype == DType::F16 && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi)) @@ -181,25 +184,32 @@ impl CandleBackend { && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" { tracing::info!("Starting FlashJinaBertModel model on {:?}", device); - Ok(Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?,)) + Ok(Box::new( + FlashJinaBertModel::load(vb, &config, model_type).s()?, + )) } else { tracing::info!("Starting JinaBertModel model on {:?}", 
device); Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?)) } - #[cfg(feature = "cuda")] - (Config::JinaCodeBert(config), Device::Cuda(_)) => { - if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) - && dtype == DType::F16 - && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi)) - // Allow disabling because of flash attention v1 precision problems - // See: https://github.com/huggingface/text-embeddings-inference/issues/37 - && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" - { - tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device); - Ok(Box::new(FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,)) - } else { - tracing::info!("Starting JinaCodeBertModel model on {:?}", device); - Ok(Box::new(JinaCodeBertModel::load(vb, &config, model_type).s()?)) + } + #[cfg(feature = "cuda")] + (Config::JinaCodeBert(config), Device::Cuda(_)) => { + if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1")) + && dtype == DType::F16 + && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi)) + // Allow disabling because of flash attention v1 precision problems + // See: https://github.com/huggingface/text-embeddings-inference/issues/37 + && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true" + { + tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device); + Ok(Box::new( + FlashJinaCodeBertModel::load(vb, &config, model_type).s()?, + )) + } else { + tracing::info!("Starting JinaCodeBertModel model on {:?}", device); + Ok(Box::new( + JinaCodeBertModel::load(vb, &config, model_type).s()?, + )) } } #[cfg(feature = "cuda")] diff --git a/backends/candle/src/models.rs b/backends/candle/src/models.rs index 7f098cfb..d64f5a35 100644 --- a/backends/candle/src/models.rs +++ b/backends/candle/src/models.rs @@ -7,6 +7,7 @@ extern crate accelerate_src; mod bert; mod distilbert; mod jina; +mod jina_code; mod nomic; #[cfg(feature = "cuda")] @@ -27,8 +28,8 @@ mod flash_distilbert; pub use bert::{BertConfig, BertModel, PositionEmbeddingType}; use candle::{Result, Tensor}; pub use distilbert::{DistilBertConfig, DistilBertModel}; -pub use jina::{JinaConfig, JinaBertModel}; -pub use jina_code::{JinaCodeConfig, JinaCodeBertModel}; +pub use jina::{JinaBertModel, JinaConfig}; +pub use jina_code::{JinaCodeBertModel, JinaCodeConfig}; pub use nomic::{NomicBertModel, NomicConfig}; use text_embeddings_backend_core::Batch; @@ -41,7 +42,6 @@ pub use flash_jina::FlashJinaBertModel; #[cfg(feature = "cuda")] pub use flash_jina_code::FlashJinaCodeBertModel; - #[cfg(feature = "cuda")] pub use flash_nomic::FlashNomicBertModel; diff --git a/backends/candle/src/models/flash_jina.rs b/backends/candle/src/models/flash_jina.rs index e128252a..0e1d3006 100644 --- a/backends/candle/src/models/flash_jina.rs +++ b/backends/candle/src/models/flash_jina.rs @@ -2,8 +2,8 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; use crate::models::bert::PositionEmbeddingType; -use crate::models::jina::{JinaConfig, BertEmbeddings}; use crate::models::jina::BertEmbeddings; +use crate::models::jina::{BertEmbeddings, JinaConfig}; use crate::models::Model; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::VarBuilder; diff --git 
a/backends/candle/src/models/flash_jina_code.rs b/backends/candle/src/models/flash_jina_code.rs index 97ca5fc0..0779c5d7 100644 --- a/backends/candle/src/models/flash_jina_code.rs +++ b/backends/candle/src/models/flash_jina_code.rs @@ -2,7 +2,7 @@ use crate::alibi::alibi_head_slopes; use crate::flash_attn::flash_attn_varlen; use crate::layers::{HiddenAct, LayerNorm, Linear}; use crate::models::bert::PositionEmbeddingType; -use crate::models::jina::{JinaCodeConfig, BertEmbeddings}; +use crate::models::jina::{BertEmbeddings, JinaCodeConfig}; use crate::models::Model; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_nn::VarBuilder; @@ -28,7 +28,11 @@ struct AlibiBertAttention { } impl AlibiBertAttention { - pub fn load(vb: VarBuilder, config: &JinaCodeConfig, alibi_slopes: Option) -> Result { + pub fn load( + vb: VarBuilder, + config: &JinaCodeConfig, + alibi_slopes: Option, + ) -> Result { let attention_head_size = config.hidden_size / config.num_attention_heads; let all_head_size = config.num_attention_heads * attention_head_size; let hidden_size = config.hidden_size; @@ -116,9 +120,15 @@ impl AlibiBertAttention { new_qkv_shape.push(self.num_attention_heads); new_qkv_shape.push(self.attention_head_size); - let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; - let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; - let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + let query_layer = query_layer + .reshape(new_qkv_shape.as_slice())? + .transpose(1, 2)?; + let key_layer = key_layer + .reshape(new_qkv_shape.as_slice())? + .transpose(1, 2)?; + let value_layer = value_layer + .reshape(new_qkv_shape.as_slice())? + .transpose(1, 2)?; let attention = flash_attn_varlen( query_layer, @@ -135,7 +145,9 @@ impl AlibiBertAttention { let attention = attention.flatten_from(candle::D::Minus2)?; let hidden_states = self.dense.forward(&attention)?; - let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?; + let hidden_states = self + .layer_norm_out + .forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -168,7 +180,10 @@ impl JinaBertLayer { .pp("mlp") .pp("down_layer") .get((config.hidden_size, config.intermediate_size), "weight")?; - let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?; + let down_bias = vb + .pp("mlp") + .pp("down_layer") + .get(config.hidden_size, "bias")?; let down_layer = Linear::new(down_weight, Some(down_bias), None); let layer_norm_1 = LayerNorm::load( @@ -455,4 +470,4 @@ impl Model for FlashJinaCodeBertModel { fn embed(&self, batch: Batch) -> Result<(Option, Option)> { self.forward(batch) } -} \ No newline at end of file +} diff --git a/backends/candle/src/models/jina.rs b/backends/candle/src/models/jina.rs index 3f5d5916..b1d75d94 100644 --- a/backends/candle/src/models/jina.rs +++ b/backends/candle/src/models/jina.rs @@ -30,7 +30,6 @@ pub struct JinaConfig { pub id2label: Option>, } - #[derive(Debug)] pub struct BertEmbeddings { word_embeddings: Embedding, diff --git a/backends/candle/src/models/jina_code.rs b/backends/candle/src/models/jina_code.rs index cb4084d7..ec8a8c84 100644 --- a/backends/candle/src/models/jina_code.rs +++ b/backends/candle/src/models/jina_code.rs @@ -30,7 +30,6 @@ pub struct JinaCodeConfig { pub id2label: Option>, } - #[derive(Debug)] pub struct BertEmbeddings { word_embeddings: Embedding, @@ -201,9 +200,15 @@ impl BertAttention { new_qkv_shape.push(self.num_attention_heads); 
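        // new_qkv_shape ends up as [batch, seq_len, num_heads, head_size]; the
        // transpose(1, 2) below then moves the head axis ahead of seq_len so the
        // attention scores are computed independently per head.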
new_qkv_shape.push(self.attention_head_size); - let query_layer = query_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; - let key_layer = key_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; - let value_layer = value_layer.reshape(new_qkv_shape.as_slice())?.transpose(1, 2)?; + let query_layer = query_layer + .reshape(new_qkv_shape.as_slice())? + .transpose(1, 2)?; + let key_layer = key_layer + .reshape(new_qkv_shape.as_slice())? + .transpose(1, 2)?; + let value_layer = value_layer + .reshape(new_qkv_shape.as_slice())? + .transpose(1, 2)?; #[allow(unused_variables)] let context_layer = if let (Device::Cuda(_), Some(cublaslt)) = @@ -276,7 +281,9 @@ impl BertAttention { let context_layer = context_layer.transpose(1, 2)?.flatten_from(D::Minus2)?; let hidden_states = self.dense.forward(&context_layer)?; - let hidden_states = self.layer_norm_out.forward(&hidden_states, Some(&residual))?; + let hidden_states = self + .layer_norm_out + .forward(&hidden_states, Some(&residual))?; Ok(hidden_states) } @@ -309,7 +316,10 @@ impl JinaCodeBertLayer { .pp("mlp") .pp("down_layer") .get((config.hidden_size, config.intermediate_size), "weight")?; - let down_bias = vb.pp("mlp").pp("down_layer").get(config.hidden_size, "bias")?; + let down_bias = vb + .pp("mlp") + .pp("down_layer") + .get(config.hidden_size, "bias")?; let down_layer = Linear::new(down_weight, Some(down_bias), None); let layer_norm_1 = LayerNorm::load( diff --git a/backends/python/server/text_embeddings_server/layers/attention/rocm.py b/backends/python/server/text_embeddings_server/layers/attention/rocm.py index 365e5451..9ed9004c 100644 --- a/backends/python/server/text_embeddings_server/layers/attention/rocm.py +++ b/backends/python/server/text_embeddings_server/layers/attention/rocm.py @@ -42,4 +42,4 @@ def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False): is_causal, False, None, - ) \ No newline at end of file + ) diff --git a/backends/python/server/text_embeddings_server/layers/layernorm.py b/backends/python/server/text_embeddings_server/layers/layernorm.py index abd9e676..0834b734 100644 --- a/backends/python/server/text_embeddings_server/layers/layernorm.py +++ b/backends/python/server/text_embeddings_server/layers/layernorm.py @@ -41,7 +41,7 @@ def __init__(self, prefix, handle, device, dtype, config: BertConfig): self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device) self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device) self.variance_epsilon = config.layer_norm_eps - + def forward(self, hidden_states, residual=None): if residual is not None: hidden_states += residual @@ -51,4 +51,4 @@ def forward(self, hidden_states, residual=None): return hidden_states, residual else: - raise ValueError("System not recognized") \ No newline at end of file + raise ValueError("System not recognized") diff --git a/backends/python/server/text_embeddings_server/layers/pooling.py b/backends/python/server/text_embeddings_server/layers/pooling.py index 1bccbc57..7eaddb6b 100644 --- a/backends/python/server/text_embeddings_server/layers/pooling.py +++ b/backends/python/server/text_embeddings_server/layers/pooling.py @@ -16,7 +16,7 @@ def mean_pooling(embedding, cu_seqlens, max_s): indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() embedding_padded = pad_input(embedding, indices, batch_size, max_s) - + sum_embeddings = torch.sum(embedding_padded, 1) - return sum_embeddings / seqlens[:, None] \ No newline at end of file + return sum_embeddings / seqlens[:, None] diff 
--git a/backends/python/server/text_embeddings_server/models/flash_bert.py b/backends/python/server/text_embeddings_server/models/flash_bert.py index 6ebb70d4..40003013 100644 --- a/backends/python/server/text_embeddings_server/models/flash_bert.py +++ b/backends/python/server/text_embeddings_server/models/flash_bert.py @@ -234,4 +234,4 @@ def embed(self, batch: FlashBatch) -> List[Embedding]: for i in range(len(batch)) ] else: - raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") \ No newline at end of file + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index ef33b7d2..8b137891 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -1,119 +1 @@ -mod logging; -mod management; -use backend_grpc_client::Client; -use nohash_hasher::BuildNoHashHasher; -use std::collections::HashMap; -use text_embeddings_backend_core::{ - Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, -}; -use tokio::runtime::Runtime; - -pub struct PythonBackend { - _backend_process: management::BackendProcess, - tokio_runtime: Runtime, - backend_client: Client, -} - -impl PythonBackend { - pub fn new( - model_path: String, - dtype: String, - model_type: ModelType, - uds_path: String, - otlp_endpoint: Option, - otlp_service_name: String, - pooling_mode: String, - ) -> Result { - match model_type { - ModelType::Classifier => { - return Err(BackendError::Start( - "`classifier` model type is not supported".to_string(), - )) - } - ModelType::Embedding(pool) => { - if pool != Pool::Cls && pool != Pool::Mean { - return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. 
Please open an issue."))); - } - pool - } - }; - - let backend_process = management::BackendProcess::new( - model_path, - dtype, - &uds_path, - otlp_endpoint, - otlp_service_name, - pooling_mode, - )?; - let tokio_runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?; - - let backend_client = tokio_runtime - .block_on(Client::connect_uds(uds_path)) - .map_err(|err| { - BackendError::Start(format!("Could not connect to backend process: {err}")) - })?; - - Ok(Self { - _backend_process: backend_process, - tokio_runtime, - backend_client, - }) - } -} - -impl Backend for PythonBackend { - fn health(&self) -> Result<(), BackendError> { - if self - .tokio_runtime - .block_on(self.backend_client.clone().health()) - .is_err() - { - return Err(BackendError::Unhealthy); - } - Ok(()) - } - - fn is_padded(&self) -> bool { - false - } - - fn embed(&self, batch: Batch) -> Result { - if !batch.raw_indices.is_empty() { - return Err(BackendError::Inference( - "raw embeddings are not supported for the Python backend.".to_string(), - )); - } - let batch_size = batch.len(); - - let results = self - .tokio_runtime - .block_on(self.backend_client.clone().embed( - batch.input_ids, - batch.token_type_ids, - batch.position_ids, - batch.cumulative_seq_lengths, - batch.max_length, - )) - .map_err(|err| BackendError::Inference(err.to_string()))?; - let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); - - let mut embeddings = - HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); - for (i, e) in pooled_embeddings.into_iter().enumerate() { - embeddings.insert(i, Embedding::Pooled(e)); - } - - Ok(embeddings) - } - - fn predict(&self, _batch: Batch) -> Result { - Err(BackendError::Inference( - "`predict` is not implemented".to_string(), - )) - } -} diff --git a/backends/src/lib.rs b/backends/src/lib.rs index db27cddc..d332b4a7 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -39,7 +39,6 @@ impl Backend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, - pooling_mode: String, ) -> Result { let (backend_sender, backend_receiver) = mpsc::unbounded_channel(); @@ -50,7 +49,6 @@ impl Backend { uds_path, otlp_endpoint, otlp_service_name, - pooling_mode, )?; let padded_model = backend.is_padded(); let max_batch_size = backend.max_batch_size(); @@ -140,7 +138,6 @@ fn init_backend( uds_path: String, otlp_endpoint: Option, otlp_service_name: String, - pooling_mode: String, ) -> Result, BackendError> { if cfg!(feature = "candle") { #[cfg(feature = "candle")] @@ -161,7 +158,6 @@ fn init_backend( uds_path, otlp_endpoint, otlp_service_name, - pooling_mode, ) }) .join() diff --git a/docs/source/en/local_amd_gpu.md b/docs/source/en/local_amd_gpu.md index 8dc8e2de..2cfab5ac 100644 --- a/docs/source/en/local_amd_gpu.md +++ b/docs/source/en/local_amd_gpu.md @@ -37,4 +37,4 @@ and curl 127.0.0.1:80/embed \ -X POST -d '{"inputs":"What is Deep Learning?"}' \ -H 'Content-Type: application/json' -``` \ No newline at end of file +``` diff --git a/router/src/lib.rs b/router/src/lib.rs index 14f1dfb3..d2023515 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -105,7 +105,7 @@ pub async fn run( serde_json::from_str(&config).context("Failed to parse `config.json`")?; // Set model type from config - let (backend_model_type, inferred_pooling) = get_backend_model_type(&config, &model_root, &pooling)?; + let backend_model_type = 
get_backend_model_type(&config, &model_root, pooling)?; // Info model type let model_type = match &backend_model_type { @@ -191,11 +191,6 @@ pub async fn run( } }); - let pooling_str = match inferred_pooling { - Some(pool) => pool.to_string(), - None => "none".to_string(), - }; - // Create backend tracing::info!("Starting model backend"); let backend = text_embeddings_backend::Backend::new( @@ -205,7 +200,6 @@ pub async fn run( uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), otlp_service_name.clone(), - pooling_str, ) .context("Could not create backend")?; backend @@ -312,24 +306,24 @@ pub async fn run( fn get_backend_model_type( config: &ModelConfig, model_root: &Path, - pooling: &Option, -) -> Result<(text_embeddings_backend::ModelType, Option)> { + pooling: Option, +) -> Result { for arch in &config.architectures { - if Some(text_embeddings_backend::Pool::Splade) == *pooling && arch.ends_with("MaskedLM") { - return Ok((text_embeddings_backend::ModelType::Embedding( + if Some(text_embeddings_backend::Pool::Splade) == pooling && arch.ends_with("MaskedLM") { + return Ok(text_embeddings_backend::ModelType::Embedding( text_embeddings_backend::Pool::Splade, - ), Some(text_embeddings_backend::Pool::Splade))); + )); } else if arch.ends_with("Classification") { if pooling.is_some() { tracing::warn!( "`--pooling` arg is set but model is a classifier. Ignoring `--pooling` arg." ); } - return Ok((text_embeddings_backend::ModelType::Classifier, None)); + return Ok(text_embeddings_backend::ModelType::Classifier); } } - if Some(text_embeddings_backend::Pool::Splade) == *pooling { + if Some(text_embeddings_backend::Pool::Splade) == pooling { return Err(anyhow!( "Splade pooling is not supported: model is not a ForMaskedLM model" )); @@ -337,7 +331,7 @@ fn get_backend_model_type( // Set pooling let pool = match pooling { - Some(pool) => pool.clone(), + Some(pool) => pool, None => { // Load pooling config let config_path = model_root.join("1_Pooling/config.json"); @@ -353,7 +347,7 @@ fn get_backend_model_type( } } }; - Ok((text_embeddings_backend::ModelType::Embedding(pool.clone()), Some(pool))) + Ok(text_embeddings_backend::ModelType::Embedding(pool)) } #[derive(Debug, Deserialize)] diff --git a/tests/README.md b/tests/README.md index c4ff5d0b..cfbf805c 100644 --- a/tests/README.md +++ b/tests/README.md @@ -31,4 +31,4 @@ Restart server with `USE_FLASH_ATTENTION=0`, and ``` python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 1 python collect.py --model-id sentence-transformers/all-MiniLM-L6-v2 --n_inp 3 -``` \ No newline at end of file +``` diff --git a/tests/collect.py b/tests/collect.py index 313c0871..640f854c 100644 --- a/tests/collect.py +++ b/tests/collect.py @@ -34,4 +34,4 @@ save_path = f"./assets/{args.model_id.replace('/', '-')}_inp{args.n_inp}{postfix}.pt" print(f"Saving embedding of shape {embedding.shape} to {save_path}") -torch.save(embedding, save_path) \ No newline at end of file +torch.save(embedding, save_path) diff --git a/tests/conftest.py b/tests/conftest.py index 6d8ed997..efdd6fc2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,7 +26,7 @@ def __init__(self, process, port: int): def _inner_health(self) -> bool: return self.process.poll() is None - + def health(self, timeout: int = 60): assert timeout > 0 for _ in range(timeout): @@ -109,5 +109,5 @@ def local_launcher( if not use_flash_attention: del env["USE_FLASH_ATTENTION"] - - return local_launcher \ No newline at end of file + + return 
local_launcher diff --git a/tests/requirements.txt b/tests/requirements.txt index b1ee0f58..74d3b667 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,3 @@ pytest pytest-asyncio -aiohttp \ No newline at end of file +aiohttp diff --git a/tests/test_default_model.py b/tests/test_default_model.py index 595fe6bf..68499928 100644 --- a/tests/test_default_model.py +++ b/tests/test_default_model.py @@ -25,4 +25,4 @@ async def test_single_query(default_model): embedding = torch.Tensor(json.loads(response.text)) reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt") - assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) diff --git a/tests/test_flash_bert.py b/tests/test_flash_bert.py index 3c3fde1c..04085522 100644 --- a/tests/test_flash_bert.py +++ b/tests/test_flash_bert.py @@ -25,4 +25,4 @@ async def test_single_query(default_model): embedding = torch.Tensor(json.loads(response.text)) reference_embedding = torch.load("./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1.pt") - assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) \ No newline at end of file + assert torch.allclose(embedding, reference_embedding, atol=1e-3, rtol=1e-3) From 20bda42c93aecc3bb7e14c7d7faeca34a560290c Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Fri, 21 Jun 2024 14:15:53 +0000 Subject: [PATCH 13/17] add file back --- backends/python/src/lib.rs | 125 +++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/backends/python/src/lib.rs b/backends/python/src/lib.rs index 8b137891..142547bc 100644 --- a/backends/python/src/lib.rs +++ b/backends/python/src/lib.rs @@ -1 +1,126 @@ +mod logging; +mod management; +use backend_grpc_client::Client; +use nohash_hasher::BuildNoHashHasher; +use std::collections::HashMap; +use text_embeddings_backend_core::{ + Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions, +}; +use tokio::runtime::Runtime; + +pub struct PythonBackend { + _backend_process: management::BackendProcess, + tokio_runtime: Runtime, + backend_client: Client, +} + +impl PythonBackend { + pub fn new( + model_path: String, + dtype: String, + model_type: ModelType, + uds_path: String, + otlp_endpoint: Option, + otlp_service_name: String, + ) -> Result { + let model_type_clone = model_type.clone(); + + match model_type { + ModelType::Classifier => { + return Err(BackendError::Start( + "`classifier` model type is not supported".to_string(), + )) + } + ModelType::Embedding(pool) => { + if pool != Pool::Cls && pool != Pool::Mean { + return Err(BackendError::Start(format!("{pool:?} is not supported in the TEI Python backend. 
Please open an issue."))); + } + pool + } + }; + + let pool_string = match &model_type_clone { + ModelType::Classifier => &Pool::Cls, + ModelType::Embedding(pool) => pool, + } + .to_string(); + + let backend_process = management::BackendProcess::new( + model_path, + dtype, + &uds_path, + otlp_endpoint, + otlp_service_name, + pool_string, + )?; + let tokio_runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|err| BackendError::Start(format!("Could not start Tokio runtime: {err}")))?; + + let backend_client = tokio_runtime + .block_on(Client::connect_uds(uds_path)) + .map_err(|err| { + BackendError::Start(format!("Could not connect to backend process: {err}")) + })?; + + Ok(Self { + _backend_process: backend_process, + tokio_runtime, + backend_client, + }) + } +} + +impl Backend for PythonBackend { + fn health(&self) -> Result<(), BackendError> { + if self + .tokio_runtime + .block_on(self.backend_client.clone().health()) + .is_err() + { + return Err(BackendError::Unhealthy); + } + Ok(()) + } + + fn is_padded(&self) -> bool { + false + } + + fn embed(&self, batch: Batch) -> Result { + if !batch.raw_indices.is_empty() { + return Err(BackendError::Inference( + "raw embeddings are not supported for the Python backend.".to_string(), + )); + } + let batch_size = batch.len(); + + let results = self + .tokio_runtime + .block_on(self.backend_client.clone().embed( + batch.input_ids, + batch.token_type_ids, + batch.position_ids, + batch.cumulative_seq_lengths, + batch.max_length, + )) + .map_err(|err| BackendError::Inference(err.to_string()))?; + let pooled_embeddings: Vec> = results.into_iter().map(|r| r.values).collect(); + + let mut embeddings = + HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default()); + for (i, e) in pooled_embeddings.into_iter().enumerate() { + embeddings.insert(i, Embedding::Pooled(e)); + } + + Ok(embeddings) + } + + fn predict(&self, _batch: Batch) -> Result { + Err(BackendError::Inference( + "`predict` is not implemented".to_string(), + )) + } +} From 2d2325ce8e7e8be897d6bd054c57ca8c6151ccd7 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 24 Jun 2024 08:21:32 +0000 Subject: [PATCH 14/17] uncomment flash error captureg --- .../server/text_embeddings_server/models/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/backends/python/server/text_embeddings_server/models/__init__.py b/backends/python/server/text_embeddings_server/models/__init__.py index c606efc9..fbd2348b 100644 --- a/backends/python/server/text_embeddings_server/models/__init__.py +++ b/backends/python/server/text_embeddings_server/models/__init__.py @@ -15,11 +15,11 @@ torch.set_grad_enabled(False) FLASH_ATTENTION = True -# try: -from text_embeddings_server.models.flash_bert import FlashBert -# except ImportError as e: -# logger.warning(f"Could not import Flash Attention enabled models: {e}") -# FLASH_ATTENTION = False +try: + from text_embeddings_server.models.flash_bert import FlashBert +except ImportError as e: + logger.warning(f"Could not import Flash Attention enabled models: {e}") + FLASH_ATTENTION = False if FLASH_ATTENTION: __all__.append(FlashBert) From 839a4455e7c13570a8ee4f216663eb1937fab109 Mon Sep 17 00:00:00 2001 From: Felix Marty <9808326+fxmarty@users.noreply.github.com> Date: Mon, 24 Jun 2024 08:23:51 +0000 Subject: [PATCH 15/17] fix rocm workflow --- .github/workflows/build_rocm.yaml | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 
deletions(-) diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml index 5971ed78..f9c2f620 100644 --- a/.github/workflows/build_rocm.yaml +++ b/.github/workflows/build_rocm.yaml @@ -36,22 +36,30 @@ steps: - name: Checkout repository uses: actions/checkout@v3 + + - name: Tailscale + uses: huggingface/tailscale-action@v1 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Initialize Docker Buildx uses: docker/setup-buildx-action@v2.0.0 with: install: true + config-inline: | + [registry."docker.io"] + mirrors = ["registry.github-runners.huggingface.tech"] + - name: Configure sccache uses: actions/github-script@v6 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + - name: Inject slug/short variables uses: rlespinasse/github-slug-action@v4.4.1 - - name: Tailscale - uses: huggingface/tailscale-action@v1 - with: - authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry if: github.event_name != 'pull_request' uses: docker/login-action@v2 @@ -59,12 +67,14 @@ registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry uses: docker/login-action@v2.1.0 with: username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} registry: registry.internal.huggingface.tech + - name: Extract metadata (tags, labels) for Docker id: meta-rocm uses: docker/metadata-action@v4.3.0 From 80259c9ee29ecb1fe92482f4552076c617c0fd6e Mon Sep 17 00:00:00 2001 From: Nicholas Broad Date: Tue, 17 Sep 2024 08:28:31 -0700 Subject: [PATCH 16/17] feat(python): add cls and mean pooling (#402) --- .../text_embeddings_server/models/default_model.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/backends/python/server/text_embeddings_server/models/default_model.py b/backends/python/server/text_embeddings_server/models/default_model.py index 17ad4589..3c41b6c3 100644 --- a/backends/python/server/text_embeddings_server/models/default_model.py +++ b/backends/python/server/text_embeddings_server/models/default_model.py @@ -43,7 +43,14 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]: kwargs["position_ids"] = batch.position_ids output = self.model(**kwargs) - embedding = output[0][:, 0] + + if self.pooling_mode == "cls": + embedding = output[0][:, 0] + elif self.pooling_mode == "mean": + embedding = output[0].mean(dim=1) + else: + raise NotImplementedError(f"Pooling {self.pooling_mode} is not implemented in the python backend") + cpu_results = embedding.view(-1).tolist() return [ From cef41a4a4ccdc82b1bb8a3d6cc6756e343564edc Mon Sep 17 00:00:00 2001 From: Mohit Sharma Date: Tue, 5 Nov 2024 20:16:29 +0530 Subject: [PATCH 17/17] ROCM build fixes (#403) Co-authored-by: root --- .github/workflows/build_rocm.yaml | 4 ++-- Dockerfile-rocm | 15 ++++++++------- docs/source/en/local_amd_gpu.md | 5 +++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build_rocm.yaml b/.github/workflows/build_rocm.yaml index f9c2f620..8ecfa1e5 100644 --- a/.github/workflows/build_rocm.yaml +++ b/.github/workflows/build_rocm.yaml @@ -95,7 +95,7 @@ uses: docker/build-push-action@v4 with: context: . - file: Dockerfile-cuda + file: Dockerfile-rocm push: ${{ github.event_name != 'pull_request' }} platforms: 'linux/amd64' build-args: | @@ -130,7 +130,7 @@ with: context: . 
target: grpc - file: Dockerfile-cuda + file: Dockerfile-rocm push: ${{ github.event_name != 'pull_request' }} platforms: 'linux/amd64' build-args: | diff --git a/Dockerfile-rocm b/Dockerfile-rocm index 152fa0a0..4ac343a4 100644 --- a/Dockerfile-rocm +++ b/Dockerfile-rocm @@ -81,8 +81,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins hipsparse-dev \ hipblas-dev \ hipblaslt-dev \ - rocblas-dev \ hiprand-dev \ + hipsolver-dev \ + rocblas-dev \ rocrand-dev \ && rm -rf /var/lib/apt/lists/* @@ -105,17 +106,17 @@ RUN chmod +x ~/mambaforge.sh && \ # Install flash-attention, torch dependencies RUN pip install numpy einops ninja --no-cache-dir -RUN pip install torch --index-url https://download.pytorch.org/whl/rocm6.0 - -ARG DEFAULT_USE_FLASH_ATTENTION=True -COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 -RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm - # Install python backend COPY backends/python/server /tei_backends/python/server COPY backends/proto tei_backends/proto RUN make -C /tei_backends/python/server install +RUN pip install --force-reinstall torch==$PYTORCH_VERSION --index-url https://download.pytorch.org/whl/rocm6.0 + +ARG DEFAULT_USE_FLASH_ATTENTION=True +COPY backends/python/Makefile-flash-att-v2 Makefile-flash-att-v2 +RUN make -f Makefile-flash-att-v2 install-flash-attention-v2-rocm + ENV HUGGINGFACE_HUB_CACHE=/data \ PORT=80 \ USE_FLASH_ATTENTION=$DEFAULT_USE_FLASH_ATTENTION diff --git a/docs/source/en/local_amd_gpu.md b/docs/source/en/local_amd_gpu.md index 2cfab5ac..7a2bfb73 100644 --- a/docs/source/en/local_amd_gpu.md +++ b/docs/source/en/local_amd_gpu.md @@ -18,16 +18,17 @@ rendered properly in your Markdown viewer. Text-Embeddings-Inference supports the [AMD GPUs officially supporting ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html), including AMD Instinct MI210, MI250, MI300 and some of the AMD Radeon series GPUs. -To leverage AMD GPUs, Text-Embeddings-Inference relies on its Python backend, and not on the [candle](https://github.com/huggingface/candle) backend that is used for CPU, Nvidia GPUs and Metal. The support in the python backend is more limited (Bert embeddings) but easily extendible. We welcome contributions to extend the supported models. +To leverage AMD GPUs, Text-Embeddings-Inference relies on its Python backend, and not on the [candle](https://github.com/huggingface/candle) backend that is used for CPU, Nvidia GPUs and Metal. The support in the python backend is more limited (Bert embeddings) but easily extensible. We welcome contributions to extend the supported models. ## Usage through docker Using docker is the recommended approach. ```bash +DOCKER_TAG=rocm-xxx # Specify the tag of the docker image to use docker run --rm -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --net host \ --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 32g \ - ghcr.io/huggingface/text-embeddings-inference:rocm-1.2.4 \ + ghcr.io/huggingface/text-embeddings-inference:$DOCKER_TAG \ --model-id sentence-transformers/all-MiniLM-L6-v2 ```
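
As a quick check of the image started above, the embeddings returned by the server can be compared against the reference tensors collected with `tests/collect.py`. The snippet below is a minimal sketch, assuming the server is already listening on port 80 (as in the `docker run` command above) and that a reference file produced by `collect.py` is available locally; the `_no_flash` variant corresponds to a server started with `USE_FLASH_ATTENTION=0`, as in the tests.

```python
# Minimal sketch: query a running TEI server and compare the result against a
# reference tensor collected with tests/collect.py (see tests/test_default_model.py).
# Assumes the server is reachable on 127.0.0.1:80 and the reference file exists.
import json

import requests
import torch

response = requests.post(
    "http://127.0.0.1:80/embed",
    json={"inputs": ["What is Deep Learning?"]},
    headers={"Content-Type": "application/json"},
)
embedding = torch.Tensor(json.loads(response.text))

# Path follows the naming convention used by tests/collect.py; pick the
# "_no_flash" file if the server was launched with USE_FLASH_ATTENTION=0.
reference = torch.load(
    "./tests/assets/sentence-transformers-all-MiniLM-L6-v2_inp1_no_flash.pt"
)

# Same tolerances as the pytest suite in this series.
assert torch.allclose(embedding, reference, atol=1e-3, rtol=1e-3)
```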