Skip to content

Commit b1ea931

Browse files
authored
feat(silero): add Silero-vad backend (#4204)
* feat(vad): add silero-vad backend (WIP) Signed-off-by: Ettore Di Giacinto <[email protected]> * feat(vad): add API endpoint Signed-off-by: Ettore Di Giacinto <[email protected]> * fix(vad): correctly place the onnxruntime libs Signed-off-by: Ettore Di Giacinto <[email protected]> * chore(vad): hook silero-vad to binary and container builds Signed-off-by: Ettore Di Giacinto <[email protected]> * feat(gRPC): register VAD Server Signed-off-by: Ettore Di Giacinto <[email protected]> * fix(Makefile): consume ONNX_OS consistently Signed-off-by: Ettore Di Giacinto <[email protected]> * fix(Makefile): handle macOS Signed-off-by: Ettore Di Giacinto <[email protected]> --------- Signed-off-by: Ettore Di Giacinto <[email protected]> Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 9892d7d commit b1ea931

File tree

15 files changed

+255
-1
lines changed

15 files changed

+255
-1
lines changed

Makefile

+39
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
3434
TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
3535
TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057
3636

37+
ONNX_VERSION?=1.20.0
38+
ONNX_ARCH?=x64
39+
ONNX_OS?=linux
40+
3741
export BUILD_TYPE?=
3842
export STABLE_BUILD_TYPE?=$(BUILD_TYPE)
3943
export CMAKE_ARGS?=
@@ -89,7 +93,20 @@ ifeq ($(NATIVE),false)
8993
CMAKE_ARGS+=-DGGML_NATIVE=OFF
9094
endif
9195

96+
# Detect if we are running on arm64
97+
ifneq (,$(findstring aarch64,$(shell uname -m)))
98+
ONNX_ARCH=aarch64
99+
endif
100+
92101
ifeq ($(OS),Darwin)
102+
ONNX_OS=osx
103+
ifneq (,$(findstring aarch64,$(shell uname -m)))
104+
ONNX_ARCH=arm64
105+
else ifneq (,$(findstring arm64,$(shell uname -m)))
106+
ONNX_ARCH=arm64
107+
else
108+
ONNX_ARCH=x86_64
109+
endif
93110

94111
ifeq ($(OSX_SIGNING_IDENTITY),)
95112
OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/')
@@ -195,6 +212,7 @@ ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
195212
ALL_GRPC_BACKENDS+=backend-assets/grpc/rwkv
196213
ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
197214
ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
215+
ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
198216
ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
199217
# Use filter-out to remove the specified backends
200218
ALL_GRPC_BACKENDS := $(filter-out $(SKIP_GRPC_BACKEND),$(ALL_GRPC_BACKENDS))
@@ -281,6 +299,20 @@ sources/go-stable-diffusion:
281299
sources/go-stable-diffusion/libstablediffusion.a: sources/go-stable-diffusion
282300
CPATH="$(CPATH):/usr/include/opencv4" $(MAKE) -C sources/go-stable-diffusion libstablediffusion.a
283301

302+
sources/onnxruntime:
303+
mkdir -p sources/onnxruntime
304+
curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
305+
cd sources/onnxruntime && tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz
306+
cd sources/onnxruntime && mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./
307+
308+
backend-assets/lib/libonnxruntime.so.1: backend-assets/lib sources/onnxruntime
309+
cp -rfv sources/onnxruntime/lib/* backend-assets/lib/
310+
ifeq ($(OS),Darwin)
311+
mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib backend-assets/lib/libonnxruntime.dylib
312+
else
313+
mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1
314+
endif
315+
284316
## tiny-dream
285317
sources/go-tiny-dream:
286318
mkdir -p sources/go-tiny-dream
@@ -837,6 +869,13 @@ ifneq ($(UPX),)
837869
$(UPX) backend-assets/grpc/stablediffusion
838870
endif
839871

872+
backend-assets/grpc/silero-vad: backend-assets/grpc backend-assets/lib/libonnxruntime.so.1
873+
CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURDIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURDIR)/backend-assets/lib \
874+
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/silero-vad ./backend/go/vad/silero
875+
ifneq ($(UPX),)
876+
$(UPX) backend-assets/grpc/silero-vad
877+
endif
878+
840879
backend-assets/grpc/tinydream: sources/go-tiny-dream sources/go-tiny-dream/libtinydream.a backend-assets/grpc
841880
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/go-tiny-dream \
842881
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/tinydream ./backend/go/image/tinydream

backend/backend.proto

+15
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ service Backend {
2828
rpc Rerank(RerankRequest) returns (RerankResult) {}
2929

3030
rpc GetMetrics(MetricsRequest) returns (MetricsResponse);
31+
32+
rpc VAD(VADRequest) returns (VADResponse) {}
3133
}
3234

3335
// Define the empty request
@@ -293,6 +295,19 @@ message TTSRequest {
293295
optional string language = 5;
294296
}
295297

298+
message VADRequest {
299+
repeated float audio = 1;
300+
}
301+
302+
message VADSegment {
303+
float start = 1;
304+
float end = 2;
305+
}
306+
307+
message VADResponse {
308+
repeated VADSegment segments = 1;
309+
}
310+
296311
message SoundGenerationRequest {
297312
string text = 1;
298313
string model = 2;

backend/go/vad/silero/main.go

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package main
2+
3+
// Note: this is started internally by LocalAI and a server is allocated for each model
4+
5+
import (
6+
"flag"
7+
8+
grpc "github.com/mudler/LocalAI/pkg/grpc"
9+
)
10+
11+
var (
12+
addr = flag.String("addr", "localhost:50051", "the address to connect to")
13+
)
14+
15+
func main() {
16+
flag.Parse()
17+
18+
if err := grpc.StartServer(*addr, &VAD{}); err != nil {
19+
panic(err)
20+
}
21+
}

backend/go/vad/silero/vad.go

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package main
2+
3+
// This is a wrapper to statisfy the GRPC service interface
4+
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
5+
import (
6+
"fmt"
7+
8+
"github.com/mudler/LocalAI/pkg/grpc/base"
9+
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
10+
"github.com/streamer45/silero-vad-go/speech"
11+
)
12+
13+
type VAD struct {
14+
base.SingleThread
15+
detector *speech.Detector
16+
}
17+
18+
func (vad *VAD) Load(opts *pb.ModelOptions) error {
19+
v, err := speech.NewDetector(speech.DetectorConfig{
20+
ModelPath: opts.ModelFile,
21+
SampleRate: 16000,
22+
//WindowSize: 1024,
23+
Threshold: 0.5,
24+
MinSilenceDurationMs: 0,
25+
SpeechPadMs: 0,
26+
})
27+
if err != nil {
28+
return fmt.Errorf("create silero detector: %w", err)
29+
}
30+
31+
vad.detector = v
32+
return err
33+
}
34+
35+
func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) {
36+
audio := req.Audio
37+
38+
segments, err := vad.detector.Detect(audio)
39+
if err != nil {
40+
return pb.VADResponse{}, fmt.Errorf("detect: %w", err)
41+
}
42+
43+
vadSegments := []*pb.VADSegment{}
44+
for _, s := range segments {
45+
vadSegments = append(vadSegments, &pb.VADSegment{
46+
Start: float32(s.SpeechStartAt),
47+
End: float32(s.SpeechEndAt),
48+
})
49+
}
50+
51+
return pb.VADResponse{
52+
Segments: vadSegments,
53+
}, nil
54+
}

core/http/endpoints/localai/vad.go

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package localai
2+
3+
import (
4+
"github.com/gofiber/fiber/v2"
5+
"github.com/mudler/LocalAI/core/backend"
6+
"github.com/mudler/LocalAI/core/config"
7+
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
8+
"github.com/mudler/LocalAI/core/schema"
9+
"github.com/mudler/LocalAI/pkg/grpc/proto"
10+
"github.com/mudler/LocalAI/pkg/model"
11+
"github.com/rs/zerolog/log"
12+
)
13+
14+
// VADEndpoint is Voice-Activation-Detection endpoint
15+
// @Summary Detect voice fragments in an audio stream
16+
// @Accept json
17+
// @Param request body schema.VADRequest true "query params"
18+
// @Success 200 {object} proto.VADResponse "Response"
19+
// @Router /vad [post]
20+
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
21+
return func(c *fiber.Ctx) error {
22+
23+
input := new(schema.VADRequest)
24+
25+
// Get input data from the request body
26+
if err := c.BodyParser(input); err != nil {
27+
return err
28+
}
29+
30+
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false)
31+
if err != nil {
32+
modelFile = input.Model
33+
log.Warn().Msgf("Model not found in context: %s", input.Model)
34+
}
35+
36+
cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath,
37+
config.LoadOptionDebug(appConfig.Debug),
38+
config.LoadOptionThreads(appConfig.Threads),
39+
config.LoadOptionContextSize(appConfig.ContextSize),
40+
config.LoadOptionF16(appConfig.F16),
41+
)
42+
43+
if err != nil {
44+
log.Err(err)
45+
modelFile = input.Model
46+
log.Warn().Msgf("Model not found in context: %s", input.Model)
47+
} else {
48+
modelFile = cfg.Model
49+
}
50+
log.Debug().Msgf("Request for model: %s", modelFile)
51+
52+
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend), model.WithModel(modelFile))
53+
54+
vadModel, err := ml.Load(opts...)
55+
if err != nil {
56+
return err
57+
}
58+
req := proto.VADRequest{
59+
Audio: input.Audio,
60+
}
61+
resp, err := vadModel.VAD(c.Context(), &req)
62+
if err != nil {
63+
return err
64+
}
65+
66+
return c.JSON(resp)
67+
}
68+
}

core/http/routes/localai.go

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ func RegisterLocalAIRoutes(app *fiber.App,
3434
}
3535

3636
app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig))
37+
app.Post("/vad", localai.VADEndpoint(cl, ml, appConfig))
3738

3839
// Stores
3940
sl := model.NewModelLoader("")

core/schema/localai.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,16 @@ type TTSRequest struct {
3030
Input string `json:"input" yaml:"input"` // text input
3131
Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id
3232
Backend string `json:"backend" yaml:"backend"`
33-
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
33+
Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
3434
Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
3535
}
3636

37+
// @Description VAD request body
38+
type VADRequest struct {
39+
Model string `json:"model" yaml:"model"` // model name or full path
40+
Audio []float32 `json:"audio" yaml:"audio"` // model name or full path
41+
}
42+
3743
type StoresSet struct {
3844
Store string `json:"store,omitempty" yaml:"store,omitempty"`
3945

go.mod

+3
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ require (
8686
github.com/pion/turn/v2 v2.1.6 // indirect
8787
github.com/pion/webrtc/v3 v3.3.0 // indirect
8888
github.com/shirou/gopsutil/v4 v4.24.7 // indirect
89+
github.com/streamer45/silero-vad-go v0.2.1 // indirect
90+
github.com/urfave/cli/v2 v2.27.4 // indirect
91+
github.com/valyala/fasttemplate v1.2.2 // indirect
8992
github.com/wlynxg/anet v0.0.4 // indirect
9093
go.uber.org/mock v0.4.0 // indirect
9194
)

go.sum

+5
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,11 @@ github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:Udh
674674
github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA=
675675
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
676676
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
677+
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
678+
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
679+
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
680+
github.com/streamer45/silero-vad-go v0.2.1 h1:Li1/tTC4H/3cyw6q4weX+U8GWwEL3lTekK/nYa1Cvuk=
681+
github.com/streamer45/silero-vad-go v0.2.1/go.mod h1:B+2FXs/5fZ6pzl6unUZYhZqkYdOB+3saBVzjOzdZnUs=
677682
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
678683
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
679684
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

pkg/grpc/backend.go

+2
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,6 @@ type Backend interface {
5353
Rerank(ctx context.Context, in *pb.RerankRequest, opts ...grpc.CallOption) (*pb.RerankResult, error)
5454

5555
GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error)
56+
57+
VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error)
5658
}

pkg/grpc/base/base.go

+4
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ func (llm *Base) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
9292
return pb.StoresFindResult{}, fmt.Errorf("unimplemented")
9393
}
9494

95+
func (llm *Base) VAD(*pb.VADRequest) (pb.VADResponse, error) {
96+
return pb.VADResponse{}, fmt.Errorf("unimplemented")
97+
}
98+
9599
func memoryUsage() *pb.MemoryUsageData {
96100
mud := pb.MemoryUsageData{
97101
Breakdown: make(map[string]uint64),

pkg/grpc/client.go

+18
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,21 @@ func (c *Client) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opt
392392
client := pb.NewBackendClient(conn)
393393
return client.GetMetrics(ctx, in, opts...)
394394
}
395+
396+
func (c *Client) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
397+
if !c.parallel {
398+
c.opMutex.Lock()
399+
defer c.opMutex.Unlock()
400+
}
401+
c.setBusy(true)
402+
defer c.setBusy(false)
403+
c.wdMark()
404+
defer c.wdUnMark()
405+
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
406+
if err != nil {
407+
return nil, err
408+
}
409+
defer conn.Close()
410+
client := pb.NewBackendClient(conn)
411+
return client.VAD(ctx, in, opts...)
412+
}

pkg/grpc/embed.go

+4
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ func (e *embedBackend) Rerank(ctx context.Context, in *pb.RerankRequest, opts ..
8787
return e.s.Rerank(ctx, in)
8888
}
8989

90+
func (e *embedBackend) VAD(ctx context.Context, in *pb.VADRequest, opts ...grpc.CallOption) (*pb.VADResponse, error) {
91+
return e.s.VAD(ctx, in)
92+
}
93+
9094
func (e *embedBackend) GetTokenMetrics(ctx context.Context, in *pb.MetricsRequest, opts ...grpc.CallOption) (*pb.MetricsResponse, error) {
9195
return e.s.GetMetrics(ctx, in)
9296
}

pkg/grpc/interface.go

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type LLM interface {
2424
StoresDelete(*pb.StoresDeleteOptions) error
2525
StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error)
2626
StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error)
27+
28+
VAD(*pb.VADRequest) (pb.VADResponse, error)
2729
}
2830

2931
func newReply(s string) *pb.Reply {

pkg/grpc/server.go

+12
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,18 @@ func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.
227227
return &res, nil
228228
}
229229

230+
func (s *server) VAD(ctx context.Context, in *pb.VADRequest) (*pb.VADResponse, error) {
231+
if s.llm.Locking() {
232+
s.llm.Lock()
233+
defer s.llm.Unlock()
234+
}
235+
res, err := s.llm.VAD(in)
236+
if err != nil {
237+
return nil, err
238+
}
239+
return &res, nil
240+
}
241+
230242
func StartServer(address string, model LLM) error {
231243
lis, err := net.Listen("tcp", address)
232244
if err != nil {

0 commit comments

Comments
 (0)