Recursive goroutines: what is the neatest way to tell Go to stop reading from a channel?


I want to know the idiomatic way to solve this (the code linked below currently throws a deadlock error). The recursion branches an unknown number of times, so I cannot simply close the channel.

http://play.golang.org/p/avLf_sQJj_

I have made it work by passing a pointer to a number and incrementing it, and I've looked into using sync.WaitGroup. But I didn't feel (and I may be wrong) that I'd come up with an elegant solution. The Go examples I have seen tend to be simple, clever and concise.

This is the last exercise from A Tour of Go: https://tour.golang.org/#73

Do you know 'how a Go programmer' would manage this? Any help would be appreciated. I'm trying to learn well from the start.


There are 3 answers

Answer by tomasz (accepted)

Instead of involving sync.WaitGroup, you could extend the result sent for a parsed URL to include the number of new URLs found. In your main loop you would then keep reading results as long as there is something left to collect.

In your case the number of URLs found would equal the number of goroutines spawned, but it doesn't necessarily need to. I would personally spawn a more or less fixed number of fetching goroutines, so you don't open too many HTTP requests (or at least you have control over it). Your main loop wouldn't change then, as it doesn't care how the fetching is executed. The important point is that you need to send either a result or an error for each URL – I've modified the code here so it doesn't spawn new goroutines when the depth is already 1.
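
To make that fixed-pool idea concrete, here is a minimal, self-contained sketch (this is not the answer's code below; fakeFetch, worker and fetchResult are invented names, and feeding newly discovered URLs back into the jobs channel – which a real crawler would need – is deliberately left out):

package main

import (
    "fmt"
    "strings"
)

// fetchResult is a stand-in for the Res type used further below.
type fetchResult struct {
    url, body string
    err       error
}

// fakeFetch is a hypothetical stand-in for a real HTTP fetch.
func fakeFetch(url string) (string, error) {
    if strings.HasPrefix(url, "http://") {
        return "body of " + url, nil
    }
    return "", fmt.Errorf("not found: %s", url)
}

// worker pulls URLs from jobs and sends exactly one result per URL.
// Running a fixed number of workers caps the number of concurrent
// fetches no matter how many URLs are queued.
func worker(jobs <-chan string, results chan<- fetchResult) {
    for url := range jobs {
        body, err := fakeFetch(url)
        results <- fetchResult{url: url, body: body, err: err}
    }
}

func main() {
    jobs := make(chan string)
    results := make(chan fetchResult)

    // A fixed pool of 3 fetching goroutines.
    for i := 0; i < 3; i++ {
        go worker(jobs, results)
    }

    urls := []string{"http://golang.org/", "http://golang.org/pkg/", "http://golang.org/cmd/"}
    go func() {
        for _, u := range urls {
            jobs <- u
        }
        close(jobs)
    }()

    // Collect exactly one result per queued URL, just like the counting
    // loop in the crawler's main collects one Res or error per URL.
    for range urls {
        r := <-results
        if r.err != nil {
            fmt.Println(r.err)
            continue
        }
        fmt.Printf("found: %s %q\n", r.url, r.body)
    }
}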

A side effect of this solution is that you can easily print the progress in your main loop.
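
For instance, since the collection loop in main below already tracks both n and tocollect, a progress line is a one-line, purely illustrative addition:

        case s := <-ch:
            fmt.Printf("found: %s %q\n", s.url, s.body)
            tocollect += s.found
            // n+1 results collected so far; tocollect is the running total expected.
            fmt.Printf("progress: %d of %d\n", n+1, tocollect)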

Here is the example on playground:

http://play.golang.org/p/BRlUc6bojf

package main

import (
    "fmt"
)

type Fetcher interface {
    // Fetch returns the body of URL and
    // a slice of URLs found on that page.
    Fetch(url string) (body string, urls []string, err error)
}

type Res struct {
    url string
    body string
    found int // Number of new urls found
}

// Crawl uses fetcher to recursively crawl
// pages starting with url, to a maximum of depth.
func Crawl(url string, depth int, fetcher Fetcher, ch chan Res, errs chan error, visited map[string]bool) {
    body, urls, err := fetcher.Fetch(url)
    visited[url] = true
    if err != nil {
        errs <- err
        return
    }

    newUrls := 0
    if depth > 1 {
        for _, u := range urls {
            if !visited[u] {
                newUrls++
                go Crawl(u, depth-1, fetcher, ch, errs, visited)
            }
        }
    }

    // Send the result along with number of urls to be fetched
    ch <- Res{url, body, newUrls}

    return
}

func main() {
    ch := make(chan Res)
    errs := make(chan error)
    visited := map[string]bool{}
    go Crawl("http://golang.org/", 4, fetcher, ch, errs, visited)
    tocollect := 1
    for n := 0; n < tocollect; n++ {
        select {
        case s := <-ch:
            fmt.Printf("found: %s %q\n", s.url, s.body)
            tocollect += s.found
        case e := <-errs:
            fmt.Println(e)
        }
    }

}

// fakeFetcher is Fetcher that returns canned results.
type fakeFetcher map[string]*fakeResult

type fakeResult struct {
    body string
    urls []string
}

func (f fakeFetcher) Fetch(url string) (string, []string, error) {
    if res, ok := f[url]; ok {
        return res.body, res.urls, nil
    }
    return "", nil, fmt.Errorf("not found: %s", url)
}

// fetcher is a populated fakeFetcher.
var fetcher = fakeFetcher{
    "http://golang.org/": &fakeResult{
        "The Go Programming Language",
        []string{
            "http://golang.org/pkg/",
            "http://golang.org/cmd/",
        },
    },
    "http://golang.org/pkg/": &fakeResult{
        "Packages",
        []string{
            "http://golang.org/",
            "http://golang.org/cmd/",
            "http://golang.org/pkg/fmt/",
            "http://golang.org/pkg/os/",
        },
    },
    "http://golang.org/pkg/fmt/": &fakeResult{
        "Package fmt",
        []string{
            "http://golang.org/",
            "http://golang.org/pkg/",
        },
    },
    "http://golang.org/pkg/os/": &fakeResult{
        "Package os",
        []string{
            "http://golang.org/",
            "http://golang.org/pkg/",
        },
    },
}

And yes, follow @jimt's advice and make access to the map thread safe.
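
For instance, the visited map could be wrapped roughly like this (an illustrative sketch only, assuming "sync" is imported; the visitedSet name is made up, and jimt's UrlCache below is the fuller version):

type visitedSet struct {
    mu   sync.Mutex
    seen map[string]bool
}

// visit marks url as seen and reports whether it was new.
func (v *visitedSet) visit(url string) bool {
    v.mu.Lock()
    defer v.mu.Unlock()
    if v.seen[url] {
        return false
    }
    v.seen[url] = true
    return true
}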

Answer by jimt

Here is my interpretation of the exercise. There are many like it, but this is mine. I use sync.WaitGroup and a custom, mutex-protected map to store visited URLs, mostly because Go's standard map type is not thread safe. I also combine the data and error channels into a single structure with a method that reads from both, mostly for separation of concerns and (arguably) to keep things a little cleaner.

Example on playground:

package main

import (
    "fmt"
    "sync"
)

type Fetcher interface {
    // Fetch returns the body of URL and
    // a slice of URLs found on that page.
    Fetch(url string) (body string, urls []string, err error)
}

// Crawl uses fetcher to recursively crawl
// pages starting with url, to a maximum of depth.
func Crawl(wg *sync.WaitGroup, url string, depth int, fetcher Fetcher, cache *UrlCache, results *Results) {
    defer wg.Done()

    if depth <= 0 || !cache.AtomicSet(url) {
        return
    }

    body, urls, err := fetcher.Fetch(url)
    if err != nil {
        results.Error <- err
        return
    }

    results.Data <- [2]string{url, body}

    for _, url := range urls {
        wg.Add(1)
        go Crawl(wg, url, depth-1, fetcher, cache, results)
    }
}

func main() {
    var wg sync.WaitGroup
    cache := NewUrlCache()

    results := NewResults()
    defer results.Close()

    wg.Add(1)
    go Crawl(&wg, "http://golang.org/", 4, fetcher, cache, results)
    go results.Read()
    wg.Wait()
}

// Results defines channels which yield results for a single crawled URL.
type Results struct {
    Data  chan [2]string // url + body.
    Error chan error     // Possible fetcher error.
}

func NewResults() *Results {
    return &Results{
        Data:  make(chan [2]string, 1),
        Error: make(chan error, 1),
    }
}

func (r *Results) Close() error {
    close(r.Data)
    close(r.Error)
    return nil
}

// Read reads crawled results or errors, for as long as the channels are open.
func (r *Results) Read() {
    for {
        select {
        case data := <-r.Data:
            fmt.Println(">", data)

        case err := <-r.Error:
            fmt.Println("e", err)
        }
    }
}

// UrlCache defines a cache of URLs we've already visited.
type UrlCache struct {
    sync.Mutex
    data map[string]struct{} // Empty struct occupies 0 bytes, whereas bool takes 1 byte.
}

func NewUrlCache() *UrlCache { return &UrlCache{data: make(map[string]struct{})} }

// AtomicSet sets the given url in the cache and returns false if it already existed.
//
// All within the same locked context. Modifying a map without synchronisation is not safe
// when done from multiple goroutines. Doing an Exists() check and a Set() separately will
// create a race condition, so we must combine both in a single operation.
func (c *UrlCache) AtomicSet(url string) bool {
    c.Lock()
    _, ok := c.data[url]
    c.data[url] = struct{}{}
    c.Unlock()
    return !ok
}

// fakeFetcher is Fetcher that returns canned results.
type fakeFetcher map[string]*fakeResult

type fakeResult struct {
    body string
    urls []string
}

func (f fakeFetcher) Fetch(url string) (string, []string, error) {
    if res, ok := f[url]; ok {
        return res.body, res.urls, nil
    }
    return "", nil, fmt.Errorf("not found: %s", url)
}

// fetcher is a populated fakeFetcher.
var fetcher = fakeFetcher{
    "http://golang.org/": &fakeResult{
        "The Go Programming Language",
        []string{
            "http://golang.org/pkg/",
            "http://golang.org/cmd/",
        },
    },
    "http://golang.org/pkg/": &fakeResult{
        "Packages",
        []string{
            "http://golang.org/",
            "http://golang.org/cmd/",
            "http://golang.org/pkg/fmt/",
            "http://golang.org/pkg/os/",
        },
    },
    "http://golang.org/pkg/fmt/": &fakeResult{
        "Package fmt",
        []string{
            "http://golang.org/",
            "http://golang.org/pkg/",
        },
    },
    "http://golang.org/pkg/os/": &fakeResult{
        "Package os",
        []string{
            "http://golang.org/",
            "http://golang.org/pkg/",
        },
    },
}

This has not been tested extensively, so perhaps there are optimisations and fixes that can be applied, but it should at least give you some ideas.

Answer by Sunit Chatterjee

Here is how I solved the Web Crawler exercise of the Go Tour.

To track completion of the parallel recursion, I have used an atomic integer counter to keep track of how many URLs are being crawled by the parallel recursions. In the main function, I wait in a loop until the atomic counter is decremented back to zero.

To avoid crawling the same URL again, I have used a map protected by a mutex to keep track of crawled URLs.

Below are the relevant code snippets.

You can find the entire working code here on GitHub.

// Imports used by the snippets below (the full program on GitHub also
// defines the Fetcher interface and the fetcher value from the Tour exercise).
import (
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

// SafeHashSet is a mutex-protected set of visited URLs.
type SafeHashSet struct {
    sync.Mutex
    urls map[string]bool // Primarily we want to use this as a hash set, so the map values are not significant.
}

var (
    urlSet     SafeHashSet
    urlCounter int64
)

// Adds an URL to the Set, returns true if new url was added (if not present already)
func (m *SafeHashSet) add(newUrl string) bool {
    m.Lock()
    defer m.Unlock()
    _, ok := m.urls[newUrl]
    if !ok {
        m.urls[newUrl] = true
        return true
    }
    return false
}


// Crawl uses fetcher to recursively crawl
// pages starting with url, to a maximum of depth.
func Crawl(url string, depth int, fetcher Fetcher) {

    // Decrement the atomic URL counter when this Crawl call exits
    defer atomic.AddInt64(&urlCounter, -1)

    if depth <= 0 {
        return
    }

    // Don't process a URL that has already been processed
    isNewUrl := urlSet.add(url)

    if !isNewUrl {
        fmt.Printf("skip: \t%s\n", url)
        return
    }


    body, urls, err := fetcher.Fetch(url)
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Printf("found: \t%s %q\n", url, body)

    for _, u := range urls {
        atomic.AddInt64(&urlCounter, 1)
        // Crawl in parallel
        go Crawl(u, depth-1, fetcher)
    }
    return
}

func main() {
    urlSet = SafeHashSet{urls: make(map[string]bool)}

    atomic.AddInt64(&urlCounter, 1)
    go Crawl("https://golang.org/", 4, fetcher)

    for atomic.LoadInt64(&urlCounter) > 0 {
        time.Sleep(100 * time.Microsecond)
    }
    fmt.Println("Exiting")
}