Skip to content

feat: implement the redis features of subscribe, psubscribe, spublish #4746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions core/stores/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
RedisNode interface {
red.Cmdable
red.BitMapCmdable

Subscribe(ctx context.Context, channels ...string) *red.PubSub
PSubscribe(ctx context.Context, patterns ...string) *red.PubSub
}

// GeoLocation is used with GeoAdd to add geospatial location.
Expand Down Expand Up @@ -1223,10 +1226,14 @@
return err
}

// Publish publishes a message to a specific channel.
// Returns the number of subscribers that received the message or an error if it fails.
func (s *Redis) Publish(channel string, message interface{}) (int64, error) {
return s.PublishCtx(context.Background(), channel, message)
}

// PublishCtx publishes a message to a channel with context control.
// Returns the number of subscribers that received the message or an error if it fails.
func (s *Redis) PublishCtx(ctx context.Context, channel string, message interface{}) (int64, error) {
conn, err := getRedis(s)
if err != nil {
Expand All @@ -1235,6 +1242,54 @@
return conn.Publish(ctx, channel, message).Result()
}

// SPublish publishes a message to a specific shard channel.
// Returns the number of subscribers that received the message or an error if it fails.
func (s *Redis) SPublish(channel string, message interface{}) (int64, error) {
return s.SPublishCtx(context.Background(), channel, message)
}

// SPublishCtx publishes a message to a shard channel with context control.
// Returns the number of subscribers that received the message or an error if it fails.
func (s *Redis) SPublishCtx(ctx context.Context, channel string, message interface{}) (int64, error) {
conn, err := getRedis(s)
if err != nil {
return 0, err
}
return conn.SPublish(ctx, channel, message).Result()

Check warning on line 1258 in core/stores/redis/redis.go

View check run for this annotation

Codecov / codecov/patch

core/stores/redis/redis.go#L1258

Added line #L1258 was not covered by tests
}

// Subscribe subscribes to one or more specific channels.
// Returns a PubSub object to receive messages, or an error if it fails.
func (s *Redis) Subscribe(channels ...string) (*red.PubSub, error) {
return s.SubscribeCtx(context.Background(), channels...)
}

// SubscribeCtx subscribes to one or more specific channels with context control.
// Returns a PubSub object to receive messages, or an error if it fails.
func (s *Redis) SubscribeCtx(ctx context.Context, channels ...string) (*red.PubSub, error) {
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.Subscribe(ctx, channels...), nil
}

// PSubscribe subscribes to one or more channel patterns (e.g., "channel.*").
// Returns a PubSub object to receive messages matching the patterns, or an error if it fails.
func (s *Redis) PSubscribe(channels ...string) (*red.PubSub, error) {
return s.PSubscribeCtx(context.Background(), channels...)
}

// PSubscribeCtx subscribes to one or more channel patterns with context control.
// Returns a PubSub object to receive messages matching the patterns, or an error if it fails.
func (s *Redis) PSubscribeCtx(ctx context.Context, channels ...string) (*red.PubSub, error) {
conn, err := getRedis(s)
if err != nil {
return nil, err
}
return conn.PSubscribe(ctx, channels...), nil
}

// Rpop is the implementation of redis rpop command.
func (s *Redis) Rpop(key string) (string, error) {
return s.RpopCtx(context.Background(), key)
Expand Down
82 changes: 82 additions & 0 deletions core/stores/redis/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"io"
"strconv"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -2118,6 +2119,87 @@ func TestRedisPublish(t *testing.T) {
})
}

func TestRedisSPublish(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).SPublish("Test", "message")
assert.NotNil(t, err)
})
}

func TestRedisSubscribe(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).Subscribe("Test")
assert.NotNil(t, err)

pubSub, err := client.Subscribe("Test")
defer pubSub.Close()
assert.Nil(t, err)

messages := []string{"message1", "message2", "message3"}
for _, msg := range messages {
_, err := client.Publish("Test", msg)
assert.Nil(t, err)
}

ch := pubSub.Channel()
receivedMessages := make([]string, 0, len(messages))
for i := 0; i < len(messages); i++ {
select {
case msg := <-ch:
receivedMessages = append(receivedMessages, msg.Payload)
case <-time.After(time.Second):
t.Error("Timeout waiting for message")
}
}
assert.Equal(t, messages, receivedMessages)
})
}

func TestRedisPSubscribe(t *testing.T) {
runOnRedis(t, func(client *Redis) {
pattern := "Test.*"
_, err := newRedis(client.Addr, badType()).PSubscribe(pattern)
assert.NotNil(t, err)

pubSubs := make([]*red.PubSub, 3)
receivedMessages := make([][]string, 3)
var mu sync.Mutex
for i := 0; i < 3; i++ {
pubSub, err := client.PSubscribe(pattern)
assert.Nil(t, err)
pubSubs[i] = pubSub
receivedMessages[i] = make([]string, 0)
go func(idx int) {
ch := pubSub.Channel()
for msg := range ch {
mu.Lock()
receivedMessages[idx] = append(receivedMessages[idx], msg.Payload)
mu.Unlock()
}
}(i)
}

// 確保在測試結束時關閉所有訂閱
defer func() {
for _, pubSub := range pubSubs {
pubSub.Close()
}
}()
messages := []string{"message1", "message2", "message3"}
channels := []string{"Test.1", "Test.2", "Test.3"}
for i, msg := range messages {
_, err := client.Publish(channels[i], msg)
assert.Nil(t, err)
}
time.Sleep(10 * time.Millisecond)
for i := 0; i < 3; i++ {
mu.Lock()
assert.ElementsMatch(t, messages, receivedMessages[i])
mu.Unlock()
}
})
}

func TestRedisRPopLPush(t *testing.T) {
runOnRedis(t, func(client *Redis) {
_, err := newRedis(client.Addr, badType()).RPopLPush("Source", "Destination")
Expand Down
Loading