@@ -2,10 +2,13 @@ package glance
2
2
3
3
import (
4
4
"context"
5
+ "encoding/base64"
6
+ "encoding/json"
5
7
"errors"
6
8
"fmt"
7
9
"html"
8
10
"html/template"
11
+ "io"
9
12
"net/http"
10
13
"net/url"
11
14
"strings"
@@ -33,6 +36,17 @@ type redditWidget struct {
33
36
Limit int `yaml:"limit"`
34
37
CollapseAfter int `yaml:"collapse-after"`
35
38
RequestUrlTemplate string `yaml:"request-url-template"`
39
+ RedditAppName string `yaml:"reddit-app-name"`
40
+ RedditClientID string `yaml:"reddit-client-id"`
41
+ RedditClientSecret string `yaml:"reddit-client-secret"`
42
+ redditAccessToken string
43
+ }
44
+
45
+ type redditTokenResponse struct {
46
+ AccessToken string `json:"access_token"`
47
+ TokenType string `json:"token_type"`
48
+ Scope string `json:"scope"`
49
+ ExpiresIn int `json:"expires_in"`
36
50
}
37
51
38
52
func (widget * redditWidget ) initialize () error {
@@ -62,6 +76,10 @@ func (widget *redditWidget) initialize() error {
62
76
}
63
77
}
64
78
79
+ if err := widget .fetchRedditAccessToken (); err != nil {
80
+ return fmt .Errorf ("fetching Reddit API access token: %w" , err )
81
+ }
82
+
65
83
widget .
66
84
withTitle ("r/" + widget .Subreddit ).
67
85
withTitleURL ("https://www.reddit.com/r/" + widget .Subreddit + "/" ).
@@ -87,17 +105,7 @@ func isValidRedditTopPeriod(period string) bool {
87
105
}
88
106
89
107
func (widget * redditWidget ) update (ctx context.Context ) {
90
- // TODO: refactor, use a struct to pass all of these
91
- posts , err := fetchSubredditPosts (
92
- widget .Subreddit ,
93
- widget .SortBy ,
94
- widget .TopPeriod ,
95
- widget .Search ,
96
- widget .CommentsUrlTemplate ,
97
- widget .RequestUrlTemplate ,
98
- widget .Proxy .client ,
99
- widget .ShowFlairs ,
100
- )
108
+ posts , err := widget .fetchSubredditPosts ()
101
109
102
110
if ! widget .canContinueUpdateAfterHandlingErr (err ) {
103
111
return
@@ -163,49 +171,55 @@ func templateRedditCommentsURL(template, subreddit, postId, postPath string) str
163
171
return template
164
172
}
165
173
166
- func fetchSubredditPosts (
167
- subreddit ,
168
- sort ,
169
- topPeriod ,
170
- search ,
171
- commentsUrlTemplate ,
172
- requestUrlTemplate string ,
173
- proxyClient * http.Client ,
174
- showFlairs bool ,
175
- ) (forumPostList , error ) {
176
- query := url.Values {}
177
- var requestUrl string
174
+ func (widget * redditWidget ) fetchSubredditPosts () (forumPostList , error ) {
175
+ var baseURL string
178
176
179
- if search != "" {
180
- query .Set ("q" , search + " subreddit:" + subreddit )
181
- query .Set ("sort" , sort )
177
+ if widget .redditAccessToken != "" {
178
+ baseURL = "https://oauth.reddit.com"
179
+ } else {
180
+ baseURL = "https://www.reddit.com"
182
181
}
183
182
184
- if sort == "top" {
185
- query .Set ("t" , topPeriod )
186
- }
183
+ query := url.Values {}
184
+ var requestURL string
185
+
186
+ if widget .Search != "" {
187
+ query .Set ("q" , widget .Search + " subreddit:" + widget .Subreddit )
188
+ query .Set ("sort" , widget .SortBy )
187
189
188
- if search != "" {
189
- requestUrl = fmt .Sprintf ("https://www.reddit.com/search.json?%s" , query .Encode ())
190
+ requestURL = fmt .Sprintf ("%s/search.json?%s" , baseURL , query .Encode ())
190
191
} else {
191
- requestUrl = fmt .Sprintf ("https://www.reddit.com/r/%s/%s.json?%s" , subreddit , sort , query .Encode ())
192
+ if widget .SortBy == "top" {
193
+ query .Set ("t" , widget .TopPeriod )
194
+ }
195
+
196
+ requestURL = fmt .Sprintf ("%s/r/%s/%s.json?%s" , baseURL , widget .Subreddit , widget .SortBy , query .Encode ())
192
197
}
193
198
194
199
var client requestDoer = defaultHTTPClient
195
200
196
- if requestUrlTemplate != "" {
197
- requestUrl = strings .ReplaceAll (requestUrlTemplate , "{REQUEST-URL}" , requestUrl )
198
- } else if proxyClient != nil {
199
- client = proxyClient
201
+ if widget . RequestUrlTemplate != "" {
202
+ requestURL = strings .ReplaceAll (widget . RequestUrlTemplate , "{REQUEST-URL}" , requestURL )
203
+ } else if widget . Proxy . client != nil {
204
+ client = widget . Proxy . client
200
205
}
201
206
202
- request , err := http .NewRequest ("GET" , requestUrl , nil )
207
+ request , err := http .NewRequest ("GET" , requestURL , nil )
203
208
if err != nil {
204
209
return nil , err
205
210
}
206
211
207
212
// Required to increase rate limit, otherwise Reddit randomly returns 429 even after just 2 requests
208
- setBrowserUserAgentHeader (request )
213
+ if widget .RedditAppName != "" {
214
+ request .Header .Set ("User-Agent" , fmt .Sprintf ("%s/1.0" , widget .RedditAppName ))
215
+ } else {
216
+ setBrowserUserAgentHeader (request )
217
+ }
218
+
219
+ if widget .redditAccessToken != "" {
220
+ request .Header .Set ("Authorization" , fmt .Sprintf ("Bearer %s" , widget .redditAccessToken ))
221
+ }
222
+
209
223
responseJson , err := decodeJsonFromRequest [subredditResponseJson ](client , request )
210
224
if err != nil {
211
225
return nil , err
@@ -226,10 +240,10 @@ func fetchSubredditPosts(
226
240
227
241
var commentsUrl string
228
242
229
- if commentsUrlTemplate == "" {
243
+ if widget . CommentsUrlTemplate == "" {
230
244
commentsUrl = "https://www.reddit.com" + post .Permalink
231
245
} else {
232
- commentsUrl = templateRedditCommentsURL (commentsUrlTemplate , subreddit , post .Id , post .Permalink )
246
+ commentsUrl = templateRedditCommentsURL (widget . CommentsUrlTemplate , widget . Subreddit , post .Id , post .Permalink )
233
247
}
234
248
235
249
forumPost := forumPost {
@@ -249,19 +263,19 @@ func fetchSubredditPosts(
249
263
forumPost .TargetUrl = post .Url
250
264
}
251
265
252
- if showFlairs && post .Flair != "" {
266
+ if widget . ShowFlairs && post .Flair != "" {
253
267
forumPost .Tags = append (forumPost .Tags , post .Flair )
254
268
}
255
269
256
270
if len (post .ParentList ) > 0 {
257
271
forumPost .IsCrosspost = true
258
272
forumPost .TargetUrlDomain = "r/" + post .ParentList [0 ].Subreddit
259
273
260
- if commentsUrlTemplate == "" {
274
+ if widget . CommentsUrlTemplate == "" {
261
275
forumPost .TargetUrl = "https://www.reddit.com" + post .ParentList [0 ].Permalink
262
276
} else {
263
277
forumPost .TargetUrl = templateRedditCommentsURL (
264
- commentsUrlTemplate ,
278
+ widget . CommentsUrlTemplate ,
265
279
post .ParentList [0 ].Subreddit ,
266
280
post .ParentList [0 ].Id ,
267
281
post .ParentList [0 ].Permalink ,
@@ -274,3 +288,55 @@ func fetchSubredditPosts(
274
288
275
289
return posts , nil
276
290
}
291
+
292
+ func (widget * redditWidget ) fetchRedditAccessToken () (err error ) {
293
+ // Only execute if all three parameters are set
294
+ if widget .RedditAppName == "" || widget .RedditClientID == "" || widget .RedditClientSecret == "" {
295
+ return nil
296
+ }
297
+
298
+ auth := base64 .StdEncoding .EncodeToString ([]byte (widget .RedditClientID + ":" + widget .RedditClientSecret ))
299
+
300
+ data := url.Values {"grant_type" : {"client_credentials" }}
301
+
302
+ req , err := http .NewRequest ("POST" , "https://www.reddit.com/api/v1/access_token" , strings .NewReader (data .Encode ()))
303
+ if err != nil {
304
+ return fmt .Errorf ("requesting an access token to the Reddit API: %w" , err )
305
+ }
306
+
307
+ req .Header .Add ("Authorization" , "Basic " + auth )
308
+ req .Header .Add ("User-Agent" , widget .RedditAppName + "/1.0" )
309
+ req .Header .Add ("Content-Type" , "application/x-www-form-urlencoded" )
310
+
311
+ client := & http.Client {
312
+ Timeout : time .Second * 10 ,
313
+ }
314
+
315
+ resp , err := client .Do (req )
316
+ if err != nil {
317
+ return fmt .Errorf ("querying Reddit API: %w" , err )
318
+ }
319
+
320
+ defer func () {
321
+ err = errors .Join (err , resp .Body .Close ())
322
+ }()
323
+
324
+ body , err := io .ReadAll (resp .Body )
325
+ if err != nil {
326
+ return fmt .Errorf ("reading response body: %w" , err )
327
+ }
328
+
329
+ if resp .StatusCode != http .StatusOK {
330
+ return fmt .Errorf ("API request failed with status %d: %s" , resp .StatusCode , string (body ))
331
+ }
332
+
333
+ var tokenResp redditTokenResponse
334
+ err = json .Unmarshal (body , & tokenResp )
335
+ if err != nil {
336
+ return fmt .Errorf ("unmarshalling Reddit API response: %w" , err )
337
+ }
338
+
339
+ widget .redditAccessToken = tokenResp .AccessToken
340
+
341
+ return
342
+ }
0 commit comments