Skip to content

Commit

Permalink
Merge pull request #5 from rusq/get-compat
Browse files Browse the repository at this point in the history
Compatibilty with http.Get
  • Loading branch information
rusq authored May 8, 2021
2 parents 0e93acb + 6cec881 commit 75b0fa7
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 57 deletions.
99 changes: 83 additions & 16 deletions chromedl.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"context"
"io"
"io/ioutil"
"net/http"
"os"
"path/filepath"
"strings"
Expand All @@ -36,9 +37,9 @@ type Instance struct {

ctx context.Context // context with the browser

allocCancel context.CancelFunc // allocator cancel func
browserCancel context.CancelFunc // browser cancel func
lnCancel context.CancelFunc // listener cancel func
allocFn context.CancelFunc // allocator cancel func
browserFn context.CancelFunc // browser cancel func
lnCancel context.CancelFunc // listener cancel func

guidC chan string
requestIDC chan network.RequestID
Expand Down Expand Up @@ -70,6 +71,8 @@ func OptUserAgent(ua string) Option {
}
}

// New creates a new Instance, starting up the headless chrome to do the download.
// Once finished, call Stop to terminate the browser.
func New(options ...Option) (*Instance, error) {

cfg := config{
Expand All @@ -86,6 +89,10 @@ func New(options ...Option) (*Instance, error) {
allocCtx, aCancel := chromedp.NewExecAllocator(context.Background(), opts[:]...)
ctx, cCancel := chromedp.NewContext(allocCtx, chromedp.WithLogf(dlog.Printf), chromedp.WithDebugf(dlog.Debugf))

return newInstance(ctx, cfg, aCancel, cCancel)
}

func newInstance(ctx context.Context, cfg config, allocCFn, ctxCFn context.CancelFunc) (*Instance, error) {
tmpdir, err := ioutil.TempDir("", tempPrefix+"*")
if err != nil {
return nil, err
Expand All @@ -94,9 +101,9 @@ func New(options ...Option) (*Instance, error) {
bi := Instance{
cfg: cfg,

ctx: ctx,
allocCancel: aCancel,
browserCancel: cCancel,
ctx: ctx,
allocFn: allocCFn,
browserFn: ctxCFn,

guidC: make(chan string),
requestIDC: make(chan network.RequestID),
Expand All @@ -111,31 +118,60 @@ func New(options ...Option) (*Instance, error) {
return &bi, nil
}

// ErrNoChrome indicates that there's no chrome instance in the context.
var ErrNoChrome = errors.New("no chrome instance in the context")

// NewWithChromeCtx creates new Instance for existing browser instance. Stop will not terminate
// the browser, but will cancel the event listener.
func NewWithChromeCtx(taskCtx context.Context, options ...Option) (*Instance, error) {
if chrome := chromedp.FromContext(taskCtx); chrome == nil {
return nil, ErrNoChrome
}
return newInstance(taskCtx, config{}, nil, nil)
}

func (bi *Instance) Stop() error {
bi.stopListener()
// close download channels
close(bi.guidC)
close(bi.requestIDC)
bi.browserCancel()
bi.allocCancel()

// cancel contexts if cancel functions are set
if bi.allocFn != nil {
bi.browserFn()
}
if bi.allocFn != nil {
bi.allocFn()
}

// remove temporary dir with any residual files
return os.RemoveAll(bi.tmpdir)
}

// Get downloads a file from the provided uri using the chromedp capabilities.
// Download downloads a file from the provided uri using the chromedp capabilities.
// It will return the reader with the file contents (buffered), and an error if
// any. If the error is present, reader may not be nil if the file was
// downloaded and read successfully. It will store the file in the temporary
// directory once the download is complete, then buffer it and try to cleanup
// afterwards. Set the timeout on context if required, by default no timeout is
// set. Optionally one can pass the configuration options for the downloader.
func Get(ctx context.Context, uri string, opts ...Option) (io.Reader, error) {
func Download(ctx context.Context, uri string, opts ...Option) (io.Reader, error) {
bi, err := New(opts...)
if err != nil {
return nil, err
}
defer bi.Stop()
return bi.Get(ctx, uri)
return bi.Download(ctx, uri)
}

// Get is drop-in replacement for http.Get.
func Get(url string) (*http.Response, error) {
bi, err := New()
if err != nil {
return nil, err
}
defer bi.Stop()
return bi.Get(url)
}

// stopListener stops the Listener.
Expand All @@ -161,7 +197,7 @@ func (bi *Instance) startListener() {
chromedp.ListenTarget(lnctx, bi.eventHandler)
}

// eventHandler returns an Listen
// eventHandler handles the download event.
func (bi *Instance) eventHandler(v interface{}) {
switch ev := v.(type) {
case *page.EventDownloadProgress:
Expand Down Expand Up @@ -193,13 +229,43 @@ func (bi *Instance) eventHandler(v interface{}) {
}
}

func (bi *Instance) Get(ctx context.Context, uri string) (io.Reader, error) {
// Download downloads the file returning the reader with contents.
func (bi *Instance) Download(ctx context.Context, uri string) (io.Reader, error) {
return bi.download(ctx, uri)
}
func (bi *Instance) download(ctx context.Context, uri string) (*bytes.Buffer, error) {
if err := bi.navigate(ctx, uri); err != nil {
return nil, err
}
return bi.waitTransfer(ctx)
}

// Get partly emulates http.Get to some extent and is meant to be drop-in
// replacement for http.Get in the callers code.
func (bi *Instance) Get(url string) (*http.Response, error) {
return bi.get(context.Background(), url)
}
func (bi *Instance) get(ctx context.Context, url string) (*http.Response, error) {
buf, err := bi.download(ctx, url)
if err != nil {
return nil, err
}
req, _ := http.NewRequest("GET", url, nil)
resp := http.Response{
Status: http.StatusText(http.StatusOK),
StatusCode: http.StatusOK,
Proto: "HTTP/1.0",
ProtoMajor: 1,
ProtoMinor: 0,
Body: io.NopCloser(buf),
ContentLength: int64(buf.Len()),
Close: true,
Uncompressed: true,
Request: req,
}
return &resp, nil
}

func (bi *Instance) navigate(ctx context.Context, uri string) error {
var errC = make(chan error, 1)

Expand Down Expand Up @@ -236,8 +302,8 @@ func (bi *Instance) navigate(ctx context.Context, uri string) error {

// waitTransfer waits to receive the completed download from either guid channel
// or request ID channel. Then it does what it takes to open the received data,
// buffer it and return the reader.
func (bi *Instance) waitTransfer(ctx context.Context) (io.Reader, error) {
// and returns the bytes.Buffer with data.
func (bi *Instance) waitTransfer(ctx context.Context) (*bytes.Buffer, error) {
// Listening to both available channes to return the download.
var (
b []byte
Expand All @@ -256,7 +322,8 @@ func (bi *Instance) waitTransfer(ctx context.Context) (io.Reader, error) {
case reqID := <-bi.requestIDC:
b, err = bi.readRequest(reqID)
}
return bytes.NewReader(b), err

return bytes.NewBuffer(b), err
}

func (bi *Instance) readFile(name string) ([]byte, error) {
Expand Down
83 changes: 42 additions & 41 deletions chromedl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func fakeRunnerWithErr(err error) runnerFn {
}
}

func TestBrowserDL(t *testing.T) {
func TestDownload(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
serveFile(rw, r, "test.txt", []byte("test data"))
}))
Expand All @@ -50,7 +50,7 @@ func TestBrowserDL(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r, err := Get(context.Background(), tt.uri)
r, err := Download(context.Background(), tt.uri)
if (err != nil) != tt.wantErr {
t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestMultiDL(t *testing.T) {
val := fmt.Sprintf(format, i)
iterationC <- i

r, err := bi.Get(context.Background(), srv.URL+"/"+fmt.Sprintf(filefmt, i))
r, err := bi.Download(context.Background(), srv.URL+"/"+fmt.Sprintf(filefmt, i))
if err != nil {
t.Fatalf("%+v", err)
}
Expand Down Expand Up @@ -125,9 +125,9 @@ func serveFile(w http.ResponseWriter, r *http.Request, filename string, data []b
http.ServeContent(w, r, filename, time.Now(), bytes.NewReader(data))
}

func ExampleGet() {
func ExampleDownload() {
const rbnzRates = "https://www.rbnz.govt.nz/-/media/ReserveBank/Files/Statistics/tables/b1/hb1-daily.xlsx?revision=5fa61401-a877-4607-b7ae-2e060c09935d"
r, err := Get(context.Background(), rbnzRates)
r, err := Download(context.Background(), rbnzRates)
if err != nil {
log.Fatal(err)
}
Expand Down Expand Up @@ -185,15 +185,15 @@ func TestInstance_readRequest(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bi := &Instance{
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocCancel: tt.fields.allocCancel,
browserCancel: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocFn: tt.fields.allocCancel,
browserFn: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
}
oldRunner := runner
defer func() {
Expand All @@ -212,6 +212,7 @@ func TestInstance_readRequest(t *testing.T) {
}
}

// genFile generates a temporary file in directory dir with contents.
func genFile(t *testing.T, dir string, contents string) string {
f, err := ioutil.TempFile(dir, "tmp*")
if err != nil {
Expand Down Expand Up @@ -270,15 +271,15 @@ func TestInstance_readFile(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
defer os.Remove(tt.args.name)
bi := &Instance{
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocCancel: tt.fields.allocCancel,
browserCancel: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocFn: tt.fields.allocCancel,
browserFn: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
}
got, err := bi.readFile(tt.args.name)
if (err != nil) != tt.wantErr {
Expand Down Expand Up @@ -428,15 +429,15 @@ func TestInstance_waitTransfer(t *testing.T) {
}()
runner = fakeRunnerWithErr(nil)
bi := &Instance{
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocCancel: tt.fields.allocCancel,
browserCancel: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocFn: tt.fields.allocCancel,
browserFn: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
}
if err := tt.init(bi); err != nil {
t.Fatalf("init failed: %s", err)
Expand Down Expand Up @@ -531,15 +532,15 @@ func TestInstance_navigate(t *testing.T) {
}()
runner = tt.runnerFn
bi := &Instance{
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocCancel: tt.fields.allocCancel,
browserCancel: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
cfg: tt.fields.cfg,
ctx: tt.fields.ctx,
allocFn: tt.fields.allocCancel,
browserFn: tt.fields.browserCancel,
lnCancel: tt.fields.lnCancel,
guidC: tt.fields.guidC,
requestIDC: tt.fields.requestIDC,
requests: tt.fields.requests,
tmpdir: tt.fields.tmpdir,
}
if err := bi.navigate(tt.args.ctx, tt.args.uri); (err != nil) != tt.wantErr {
t.Errorf("Instance.navigate() error = %v, wantErr %v", err, tt.wantErr)
Expand Down

0 comments on commit 75b0fa7

Please sign in to comment.