diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index 66fcef4e0ad9..7ad33bec0025 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -66,6 +66,9 @@ type ( RedisNode interface { red.Cmdable red.BitMapCmdable + + Subscribe(ctx context.Context, channels ...string) *red.PubSub + PSubscribe(ctx context.Context, patterns ...string) *red.PubSub } // GeoLocation is used with GeoAdd to add geospatial location. @@ -1223,10 +1226,14 @@ func (s *Redis) PipelinedCtx(ctx context.Context, fn func(Pipeliner) error) erro return err } +// Publish publishes a message to a specific channel. +// Returns the number of subscribers that received the message or an error if it fails. func (s *Redis) Publish(channel string, message interface{}) (int64, error) { return s.PublishCtx(context.Background(), channel, message) } +// PublishCtx publishes a message to a channel with context control. +// Returns the number of subscribers that received the message or an error if it fails. func (s *Redis) PublishCtx(ctx context.Context, channel string, message interface{}) (int64, error) { conn, err := getRedis(s) if err != nil { @@ -1235,6 +1242,54 @@ func (s *Redis) PublishCtx(ctx context.Context, channel string, message interfac return conn.Publish(ctx, channel, message).Result() } +// SPublish publishes a message to a specific shard channel. +// Returns the number of subscribers that received the message or an error if it fails. +func (s *Redis) SPublish(channel string, message interface{}) (int64, error) { + return s.SPublishCtx(context.Background(), channel, message) +} + +// SPublishCtx publishes a message to a shard channel with context control. +// Returns the number of subscribers that received the message or an error if it fails. +func (s *Redis) SPublishCtx(ctx context.Context, channel string, message interface{}) (int64, error) { + conn, err := getRedis(s) + if err != nil { + return 0, err + } + return conn.SPublish(ctx, channel, message).Result() +} + +// Subscribe subscribes to one or more specific channels. +// Returns a PubSub object to receive messages, or an error if it fails. +func (s *Redis) Subscribe(channels ...string) (*red.PubSub, error) { + return s.SubscribeCtx(context.Background(), channels...) +} + +// SubscribeCtx subscribes to one or more specific channels with context control. +// Returns a PubSub object to receive messages, or an error if it fails. +func (s *Redis) SubscribeCtx(ctx context.Context, channels ...string) (*red.PubSub, error) { + conn, err := getRedis(s) + if err != nil { + return nil, err + } + return conn.Subscribe(ctx, channels...), nil +} + +// PSubscribe subscribes to one or more channel patterns (e.g., "channel.*"). +// Returns a PubSub object to receive messages matching the patterns, or an error if it fails. +func (s *Redis) PSubscribe(channels ...string) (*red.PubSub, error) { + return s.PSubscribeCtx(context.Background(), channels...) +} + +// PSubscribeCtx subscribes to one or more channel patterns with context control. +// Returns a PubSub object to receive messages matching the patterns, or an error if it fails. +func (s *Redis) PSubscribeCtx(ctx context.Context, channels ...string) (*red.PubSub, error) { + conn, err := getRedis(s) + if err != nil { + return nil, err + } + return conn.PSubscribe(ctx, channels...), nil +} + // Rpop is the implementation of redis rpop command. func (s *Redis) Rpop(key string) (string, error) { return s.RpopCtx(context.Background(), key) diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index b34792ee83c3..c61de5077dc6 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -6,6 +6,7 @@ import ( "errors" "io" "strconv" + "sync" "testing" "time" @@ -2118,6 +2119,87 @@ func TestRedisPublish(t *testing.T) { }) } +func TestRedisSPublish(t *testing.T) { + runOnRedis(t, func(client *Redis) { + _, err := newRedis(client.Addr, badType()).SPublish("Test", "message") + assert.NotNil(t, err) + }) +} + +func TestRedisSubscribe(t *testing.T) { + runOnRedis(t, func(client *Redis) { + _, err := newRedis(client.Addr, badType()).Subscribe("Test") + assert.NotNil(t, err) + + pubSub, err := client.Subscribe("Test") + defer pubSub.Close() + assert.Nil(t, err) + + messages := []string{"message1", "message2", "message3"} + for _, msg := range messages { + _, err := client.Publish("Test", msg) + assert.Nil(t, err) + } + + ch := pubSub.Channel() + receivedMessages := make([]string, 0, len(messages)) + for i := 0; i < len(messages); i++ { + select { + case msg := <-ch: + receivedMessages = append(receivedMessages, msg.Payload) + case <-time.After(time.Second): + t.Error("Timeout waiting for message") + } + } + assert.Equal(t, messages, receivedMessages) + }) +} + +func TestRedisPSubscribe(t *testing.T) { + runOnRedis(t, func(client *Redis) { + pattern := "Test.*" + _, err := newRedis(client.Addr, badType()).PSubscribe(pattern) + assert.NotNil(t, err) + + pubSubs := make([]*red.PubSub, 3) + receivedMessages := make([][]string, 3) + var mu sync.Mutex + for i := 0; i < 3; i++ { + pubSub, err := client.PSubscribe(pattern) + assert.Nil(t, err) + pubSubs[i] = pubSub + receivedMessages[i] = make([]string, 0) + go func(idx int) { + ch := pubSub.Channel() + for msg := range ch { + mu.Lock() + receivedMessages[idx] = append(receivedMessages[idx], msg.Payload) + mu.Unlock() + } + }(i) + } + + // 確保在測試結束時關閉所有訂閱 + defer func() { + for _, pubSub := range pubSubs { + pubSub.Close() + } + }() + messages := []string{"message1", "message2", "message3"} + channels := []string{"Test.1", "Test.2", "Test.3"} + for i, msg := range messages { + _, err := client.Publish(channels[i], msg) + assert.Nil(t, err) + } + time.Sleep(10 * time.Millisecond) + for i := 0; i < 3; i++ { + mu.Lock() + assert.ElementsMatch(t, messages, receivedMessages[i]) + mu.Unlock() + } + }) +} + func TestRedisRPopLPush(t *testing.T) { runOnRedis(t, func(client *Redis) { _, err := newRedis(client.Addr, badType()).RPopLPush("Source", "Destination")