[cmd/adb] Support pulling/pushing from stdin and to stdout.

This commit is contained in:
Zach Klippenstein 2015-12-29 13:25:17 -08:00
parent e6b354ce69
commit 3e1d948164

View file

@ -5,6 +5,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"syscall"
"time" "time"
"github.com/cheggaaa/pb" "github.com/cheggaaa/pb"
@ -13,6 +14,8 @@ import (
"gopkg.in/alecthomas/kingpin.v2" "gopkg.in/alecthomas/kingpin.v2"
) )
const StdIoFilename = "-"
var ( var (
serial = kingpin.Flag("serial", "Connect to device by serial number.").Short('s').String() serial = kingpin.Flag("serial", "Connect to device by serial number.").Short('s').String()
@ -25,11 +28,11 @@ var (
pullCommand = kingpin.Command("pull", "Pull a file from the device.") pullCommand = kingpin.Command("pull", "Pull a file from the device.")
pullProgressFlag = pullCommand.Flag("progress", "Show progress.").Short('p').Bool() pullProgressFlag = pullCommand.Flag("progress", "Show progress.").Short('p').Bool()
pullRemoteArg = pullCommand.Arg("remote", "Path of source file on device.").Required().String() pullRemoteArg = pullCommand.Arg("remote", "Path of source file on device.").Required().String()
pullLocalArg = pullCommand.Arg("local", "Path of destination file.").String() pullLocalArg = pullCommand.Arg("local", "Path of destination file. If -, will write to stdout.").String()
pushCommand = kingpin.Command("push", "Push a file to the device.") pushCommand = kingpin.Command("push", "Push a file to the device.")
pushProgressFlag = pushCommand.Flag("progress", "Show progress.").Short('p').Bool() pushProgressFlag = pushCommand.Flag("progress", "Show progress.").Short('p').Bool()
pushLocalArg = pushCommand.Arg("local", "Path of source file.").Required().File() pushLocalArg = pushCommand.Arg("local", "Path of source file. If -, will read from stdin.").Required().String()
pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String() pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String()
) )
@ -136,10 +139,15 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
} }
defer remoteFile.Close() defer remoteFile.Close()
localFile, err := os.Create(localPath) var localFile io.WriteCloser
if err != nil { if localPath == StdIoFilename {
fmt.Fprintf(os.Stderr, "error opening local file %s: %s\n", localPath, err) localFile = os.Stdout
return 1 } else {
localFile, err = os.Create(localPath)
if err != nil {
fmt.Fprintf(os.Stderr, "error opening local file %s: %s\n", localPath, err)
return 1
}
} }
defer localFile.Close() defer localFile.Close()
@ -150,44 +158,71 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
return 0 return 0
} }
func push(showProgress bool, localFile *os.File, remotePath string, device goadb.DeviceDescriptor) int { func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDescriptor) int {
if remotePath == "" { if remotePath == "" {
fmt.Fprintln(os.Stderr, "error: must specify remote file") fmt.Fprintln(os.Stderr, "error: must specify remote file")
kingpin.Usage() kingpin.Usage()
return 1 return 1
} }
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) var (
localFile io.ReadCloser
info, err := os.Stat(localFile.Name()) size int
if err != nil { perms os.FileMode
fmt.Fprintf(os.Stderr, "error reading local file %s: %s\n", localFile.Name(), err) mtime time.Time
return 1 )
if localPath == "" || localPath == StdIoFilename {
localFile = os.Stdin
// 0 size will hide the progress bar.
perms = os.FileMode(0660)
mtime = goadb.MtimeOfClose
} else {
var err error
localFile, err = os.Open(localPath)
if err != nil {
fmt.Fprintf(os.Stderr, "error opening local file %s: %s\n", localPath, err)
return 1
}
info, err := os.Stat(localPath)
if err != nil {
fmt.Fprintf(os.Stderr, "error reading local file %s: %s\n", localPath, err)
return 1
}
size = int(info.Size())
perms = info.Mode().Perm()
mtime = info.ModTime()
} }
defer localFile.Close()
writer, err := client.OpenWrite(remotePath, info.Mode(), info.ModTime()) client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
writer, err := client.OpenWrite(remotePath, perms, mtime)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err) fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err)
return 1 return 1
} }
defer writer.Close() defer writer.Close()
if err := copyWithProgressAndStats(writer, localFile, int(info.Size()), showProgress); err != nil { if err := copyWithProgressAndStats(writer, localFile, size, showProgress); err != nil {
fmt.Fprintln(os.Stderr, "error pushing file:", err) fmt.Fprintln(os.Stderr, "error pushing file:", err)
return 1 return 1
} }
return 0 return 0
} }
// copyWithProgressAndStats copies src to dst.
// If showProgress is true and size is positive, a progress bar is shown.
// After copying, final stats about the transfer speed and size are shown.
// Progress and stats are printed to stderr.
func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgress bool) error { func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgress bool) error {
var progress *pb.ProgressBar var progress *pb.ProgressBar
if showProgress { if showProgress && size > 0 {
progress = pb.New(size) progress = pb.New(size)
progress.SetUnits(pb.U_BYTES) // Write to stderr in case dst is stdout.
progress.SetRefreshRate(100 * time.Millisecond) progress.Output = os.Stderr
progress.ShowSpeed = true progress.ShowSpeed = true
progress.ShowPercent = true progress.ShowPercent = true
progress.ShowTimeLeft = true progress.ShowTimeLeft = true
progress.SetUnits(pb.U_BYTES)
progress.Start() progress.Start()
dst = io.MultiWriter(dst, progress) dst = io.MultiWriter(dst, progress)
} }
@ -196,17 +231,22 @@ func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgre
copied, err := io.Copy(dst, src) copied, err := io.Copy(dst, src)
if progress != nil { if progress != nil {
// Force progress update if the transfer was really fast. progress.Finish()
progress.Update()
} }
if pathErr, ok := err.(*os.PathError); ok {
if errno, ok := pathErr.Err.(syscall.Errno); ok && errno == syscall.EPIPE {
// Pipe closed. Handle this like an EOF.
err = nil
}
}
if err != nil { if err != nil {
return err return err
} }
duration := time.Now().Sub(startTime) duration := time.Now().Sub(startTime)
rate := int64(float64(copied) / duration.Seconds()) rate := int64(float64(copied) / duration.Seconds())
fmt.Printf("%d B/s (%d bytes in %s)\n", rate, copied, duration) fmt.Fprintf(os.Stderr, "%d B/s (%d bytes in %s)\n", rate, copied, duration)
return nil return nil
} }