Skip to content

Commit 743fdcc

Browse files
authored
Merge pull request #258 from kuba--/feature-256/deterministic-iterator
Feature 256/deterministic iterator
2 parents: f3f8432 + 98087c1; commit: 743fdcc

File tree

2 files changed

+90
-152
lines changed

2 files changed

+90
-152
lines changed

repository_pool.go

Lines changed: 41 additions & 145 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@ import (
55
"io/ioutil"
66
"os"
77
"path/filepath"
8-
"runtime"
98
"strings"
109
"sync"
10+
"sync/atomic"
1111

1212
"github.com/sirupsen/logrus"
1313
"gopkg.in/src-d/go-billy-siva.v4"
@@ -194,6 +194,7 @@ var errInvalidRepoKind = errors.NewKind("invalid repo kind: %d")
194194
// GetPos retrieves a repository at a given position. If the position is
195195
// out of bounds it returns io.EOF.
196196
func (p *RepositoryPool) GetPos(pos int) (*Repository, error) {
197+
197198
if pos >= len(p.repositories) {
198199
return nil, io.EOF
199200
}
@@ -225,24 +226,25 @@ func (p *RepositoryPool) GetPos(pos int) (*Repository, error) {
225226
// RepoIter creates a new Repository iterator
226227
func (p *RepositoryPool) RepoIter() (*RepositoryIter, error) {
227228
iter := &RepositoryIter{
228-
pos: 0,
229229
pool: p,
230230
}
231+
atomic.StoreInt32(&iter.pos, 0)
231232

232233
return iter, nil
233234
}
234235

235236
// RepositoryIter iterates over all repositories in the pool
236237
type RepositoryIter struct {
237-
pos int
238+
pos int32
238239
pool *RepositoryPool
239240
}
240241

241242
// Next retrieves the next Repository. It returns io.EOF as error
242243
// when there are no more Repositories to retrieve.
243244
func (i *RepositoryIter) Next() (*Repository, error) {
244-
r, err := i.pool.GetPos(i.pos)
245-
i.pos++
245+
pos := int(atomic.LoadInt32(&i.pos))
246+
r, err := i.pool.GetPos(pos)
247+
atomic.AddInt32(&i.pos, 1)
246248

247249
return r, err
248250
}
@@ -265,19 +267,11 @@ type RowRepoIter interface {
265267
type rowRepoIter struct {
266268
mu sync.Mutex
267269

270+
currRepoIter RowRepoIter
268271
repositoryIter *RepositoryIter
269272
iter RowRepoIter
270273
session *Session
271274
ctx *sql.Context
272-
273-
wg sync.WaitGroup
274-
done chan bool
275-
err error
276-
repos chan *Repository
277-
rows chan sql.Row
278-
279-
doneMutex sync.Mutex
280-
doneClosed bool
281275
}
282276

283277
// NewRowRepoIter initializes a new repository iterator.
@@ -303,169 +297,71 @@ func NewRowRepoIter(
303297
}
304298

305299
repoIter := rowRepoIter{
300+
currRepoIter: nil,
306301
repositoryIter: rIter,
307302
iter: iter,
308303
session: s,
309304
ctx: ctx,
310-
done: make(chan bool),
311-
err: nil,
312-
repos: make(chan *Repository),
313-
rows: make(chan sql.Row),
314-
}
315-
316-
go repoIter.fillRepoChannel()
317-
318-
wNum := runtime.NumCPU()
319-
320-
for i := 0; i < wNum; i++ {
321-
repoIter.wg.Add(1)
322-
323-
go repoIter.rowReader(i)
324305
}
325306

326-
go func() {
327-
repoIter.wg.Wait()
328-
close(repoIter.rows)
329-
closeIter(&repoIter)
330-
}()
331-
332307
return &repoIter, nil
333308
}
334309

335-
func (i *rowRepoIter) setError(err error) {
310+
// Next gets the next row
311+
func (i *rowRepoIter) Next() (sql.Row, error) {
336312
i.mu.Lock()
337313
defer i.mu.Unlock()
338314

339-
i.err = err
340-
}
341-
342-
func closeIter(i *rowRepoIter) {
343-
i.doneMutex.Lock()
344-
defer i.doneMutex.Unlock()
345-
346-
if !i.doneClosed {
347-
close(i.done)
348-
i.doneClosed = true
349-
}
350-
}
351-
352-
func (i *rowRepoIter) fillRepoChannel() {
353-
defer close(i.repos)
354-
355315
for {
356316
select {
357-
case <-i.done:
358-
return
359-
360317
case <-i.ctx.Done():
361-
closeIter(i)
362-
return
318+
return nil, ErrSessionCanceled.New()
363319

364320
default:
365-
repo, err := i.repositoryIter.Next()
366-
367-
switch err {
368-
case nil:
369-
select {
370-
case <-i.done:
371-
return
321+
if i.currRepoIter == nil {
322+
repo, err := i.repositoryIter.Next()
323+
if err != nil {
324+
if err == io.EOF {
325+
return nil, io.EOF
326+
}
372327

373-
case <-i.ctx.Done():
374-
i.setError(ErrSessionCanceled.New())
375-
closeIter(i)
376-
return
328+
if i.session.SkipGitErrors {
329+
continue
330+
}
377331

378-
case i.repos <- repo:
379-
continue
332+
return nil, err
380333
}
381334

382-
case io.EOF:
383-
i.setError(io.EOF)
384-
return
385-
386-
default:
387-
if !i.session.SkipGitErrors {
388-
closeIter(i)
389-
i.setError(err)
390-
return
335+
i.currRepoIter, err = i.iter.NewIterator(repo)
336+
if err != nil {
337+
return nil, err
391338
}
392339
}
393-
}
394-
}
395-
}
396340

397-
func (i *rowRepoIter) rowReader(num int) {
398-
defer i.wg.Done()
399-
400-
for repo := range i.repos {
401-
iter, err := i.iter.NewIterator(repo)
402-
if err != nil {
403-
// guard from possible previous error
404-
select {
405-
case <-i.done:
406-
return
407-
default:
408-
i.setError(err)
409-
closeIter(i)
410-
continue
411-
}
412-
}
413-
414-
loop:
415-
for {
416-
select {
417-
case <-i.done:
418-
iter.Close()
419-
return
420-
421-
case <-i.ctx.Done():
422-
i.setError(ErrSessionCanceled.New())
423-
return
424-
425-
default:
426-
row, err := iter.Next()
427-
switch err {
428-
case nil:
429-
select {
430-
case <-i.done:
431-
iter.Close()
432-
return
433-
case i.rows <- row:
434-
}
341+
row, err := i.currRepoIter.Next()
342+
if err != nil {
343+
if err == io.EOF {
344+
i.currRepoIter.Close()
345+
i.currRepoIter = nil
346+
continue
347+
}
435348

436-
case io.EOF:
437-
iter.Close()
438-
break loop
439-
440-
default:
441-
if !i.session.SkipGitErrors {
442-
iter.Close()
443-
i.setError(err)
444-
closeIter(i)
445-
return
446-
} else {
447-
break loop
448-
}
349+
if i.session.SkipGitErrors {
350+
continue
449351
}
450-
}
451-
}
452-
}
453-
}
454352

455-
// Next gets the next row
456-
func (i *rowRepoIter) Next() (sql.Row, error) {
457-
row, ok := <-i.rows
458-
if !ok {
459-
i.mu.Lock()
460-
defer i.mu.Unlock()
353+
return nil, err
354+
}
461355

462-
return nil, i.err
356+
return row, nil
357+
}
463358
}
464-
465-
return row, nil
466359
}
467360

468361
// Close called to close the iterator
469362
func (i *rowRepoIter) Close() error {
363+
if i.currRepoIter != nil {
364+
i.currRepoIter.Close()
365+
}
470366
return i.iter.Close()
471367
}

repository_pool_test.go

Lines changed: 49 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -356,17 +356,13 @@ func testCaseRepositoryErrorIter(
356356
for {
357357
_, err := repoIter.Next()
358358
if err != nil {
359+
cancel()
359360
break
360361
}
361362
}
362363
}()
363364

364-
select {
365-
case <-repoIter.done:
366-
require.Equal(retError, repoIter.err)
367-
}
368-
369-
cancel()
365+
<-repoIter.ctx.Done()
370366
}
371367

372368
func TestRepositoryErrorIter(t *testing.T) {
@@ -396,7 +392,7 @@ func TestRepositoryErrorBadRepository(t *testing.T) {
396392

397393
count++
398394

399-
return sql.NewRow("test"), nil
395+
return sql.NewRow("test " + strconv.Itoa(count)), nil
400396
}
401397

402398
iter.newIterator = newIterator
@@ -440,3 +436,49 @@ func TestRepositoryErrorBadRow(t *testing.T) {
440436
testCaseRepositoryErrorIter(t, pool, iter, errRow, false)
441437
testCaseRepositoryErrorIter(t, pool, iter, io.EOF, true)
442438
}
439+
440+
func TestRepositoryIteratorOrder(t *testing.T) {
441+
path := fixtures.Basic().ByTag("worktree").One().Worktree().Root()
442+
pool := NewRepositoryPool()
443+
pool.Add("one", path, gitRepo)
444+
445+
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
446+
ctx := sql.NewContext(timeout,
447+
sql.WithSession(NewSession(pool, WithSkipGitErrors(true))),
448+
)
449+
iter := &testErrorIter{}
450+
newIterator := func(*Repository) (RowRepoIter, error) {
451+
return iter, nil
452+
}
453+
454+
count := 0
455+
next := func() (sql.Row, error) {
456+
if count >= 10 {
457+
return nil, io.EOF
458+
}
459+
460+
count++
461+
462+
return sql.NewRow("test " + strconv.Itoa(count)), nil
463+
}
464+
iter.newIterator = newIterator
465+
iter.next = next
466+
467+
r, err := NewRowRepoIter(ctx, iter)
468+
require.NoError(t, err)
469+
470+
repoIter, ok := r.(*rowRepoIter)
471+
require.True(t, ok)
472+
473+
func() {
474+
for i := 1; i <= 10; i++ {
475+
row, err := repoIter.Next()
476+
if err != nil {
477+
break
478+
}
479+
require.Equal(t, sql.Row{"test " + strconv.Itoa(i)}, row)
480+
}
481+
}()
482+
483+
cancel()
484+
}

0 commit comments

Comments
 (0)