Merge pull request #6 from zach-klippenstein/sync-send

Implemented sync file send and cmd/adb push command.
This commit is contained in:
Zach Klippenstein 2015-12-28 20:03:30 -08:00
commit e6b354ce69
10 changed files with 348 additions and 49 deletions

View file

@ -27,9 +27,9 @@ var (
pullRemoteArg = pullCommand.Arg("remote", "Path of source file on device.").Required().String()
pullLocalArg = pullCommand.Arg("local", "Path of destination file.").String()
pushCommand = kingpin.Command("push", "Push a file to the device.").Hidden()
pushCommand = kingpin.Command("push", "Push a file to the device.")
pushProgressFlag = pushCommand.Flag("progress", "Show progress.").Short('p').Bool()
pushLocalArg = pushCommand.Arg("local", "Path of source file.").Required().String()
pushLocalArg = pushCommand.Arg("local", "Path of source file.").Required().File()
pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String()
)
@ -143,47 +143,70 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
}
defer localFile.Close()
var output io.Writer
var updateProgress func()
if showProgress {
output, updateProgress = createProgressBarWriter(int(info.Size), localFile)
} else {
output = localFile
updateProgress = func() {}
}
startTime := time.Now()
copied, err := io.Copy(output, remoteFile)
// Force progress update if the transfer was really fast.
updateProgress()
if err != nil {
if err := copyWithProgressAndStats(localFile, remoteFile, int(info.Size), showProgress); err != nil {
fmt.Fprintln(os.Stderr, "error pulling file:", err)
return 1
}
duration := time.Now().Sub(startTime)
rate := int64(float64(copied) / duration.Seconds())
fmt.Printf("%d B/s (%d bytes in %s)\n", rate, copied, duration)
return 0
}
func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDescriptor) int {
fmt.Fprintln(os.Stderr, "not implemented")
func push(showProgress bool, localFile *os.File, remotePath string, device goadb.DeviceDescriptor) int {
if remotePath == "" {
fmt.Fprintln(os.Stderr, "error: must specify remote file")
kingpin.Usage()
return 1
}
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
info, err := os.Stat(localFile.Name())
if err != nil {
fmt.Fprintf(os.Stderr, "error reading local file %s: %s\n", localFile.Name(), err)
return 1
}
writer, err := client.OpenWrite(remotePath, info.Mode(), info.ModTime())
if err != nil {
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err)
return 1
}
defer writer.Close()
if err := copyWithProgressAndStats(writer, localFile, int(info.Size()), showProgress); err != nil {
fmt.Fprintln(os.Stderr, "error pushing file:", err)
return 1
}
return 0
}
func createProgressBarWriter(size int, w io.Writer) (progressWriter io.Writer, update func()) {
progress := pb.New(size)
func copyWithProgressAndStats(dst io.Writer, src io.Reader, size int, showProgress bool) error {
var progress *pb.ProgressBar
if showProgress {
progress = pb.New(size)
progress.SetUnits(pb.U_BYTES)
progress.SetRefreshRate(100 * time.Millisecond)
progress.ShowSpeed = true
progress.ShowPercent = true
progress.ShowTimeLeft = true
progress.Start()
dst = io.MultiWriter(dst, progress)
}
progressWriter = io.MultiWriter(w, progress)
update = progress.Update
return
startTime := time.Now()
copied, err := io.Copy(dst, src)
if progress != nil {
// Force progress update if the transfer was really fast.
progress.Update()
}
if err != nil {
return err
}
duration := time.Now().Sub(startTime)
rate := int64(float64(copied) / duration.Seconds())
fmt.Printf("%d B/s (%d bytes in %s)\n", rate, copied, duration)
return nil
}

View file

@ -3,12 +3,18 @@ package goadb
import (
"fmt"
"io"
"os"
"strings"
"time"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
// MtimeOfClose should be passed to OpenWrite to set the file modification time to the time the Close
// method is called.
var MtimeOfClose = time.Time{}
// DeviceClient communicates with a specific Android device.
type DeviceClient struct {
config ClientConfig
@ -171,6 +177,20 @@ func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) {
return reader, wrapClientError(err, c, "OpenRead(%s)", path)
}
// OpenWrite opens the file at path on the device, creating it with the permissions specified
// by perms if necessary, and returns a writer that writes to the file.
// The files modification time will be set to mtime when the WriterCloser is closed. The zero value
// is TimeOfClose, which will use the time the Close method is called as the modification time.
func (c *DeviceClient) OpenWrite(path string, perms os.FileMode, mtime time.Time) (io.WriteCloser, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, wrapClientError(err, c, "OpenWrite(%s)", path)
}
writer, err := sendFile(conn, path, perms, mtime)
return writer, wrapClientError(err, c, "OpenWrite(%s)", path)
}
// getAttribute returns the first message returned by the server by running
// <host-prefix>:<attr>, where host-prefix is determined from the DeviceDescriptor.
func (c *DeviceClient) getAttribute(attr string) (string, error) {

View file

@ -1,4 +1,3 @@
// TODO(z): Implement send.
package goadb
import (
@ -16,7 +15,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) {
if err := conn.SendOctetString("STAT"); err != nil {
return nil, err
}
if err := conn.SendString(path); err != nil {
if err := conn.SendBytes([]byte(path)); err != nil {
return nil, err
}
@ -35,7 +34,7 @@ func listDirEntries(conn *wire.SyncConn, path string) (entries *DirEntries, err
if err = conn.SendOctetString("LIST"); err != nil {
return
}
if err = conn.SendString(path); err != nil {
if err = conn.SendBytes([]byte(path)); err != nil {
return
}
@ -46,12 +45,29 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) {
if err := conn.SendOctetString("RECV"); err != nil {
return nil, err
}
if err := conn.SendString(path); err != nil {
if err := conn.SendBytes([]byte(path)); err != nil {
return nil, err
}
return newSyncFileReader(conn)
}
// sendFile returns a WriteCloser than will write to the file at path on device.
// The file will be created with permissions specified by mode.
// The file's modified time will be set to mtime, unless mtime is 0, in which case the time the writer is
// closed will be used.
func sendFile(conn *wire.SyncConn, path string, mode os.FileMode, mtime time.Time) (io.WriteCloser, error) {
if err := conn.SendOctetString("SEND"); err != nil {
return nil, err
}
pathAndMode := encodePathAndMode(path, mode)
if err := conn.SendBytes(pathAndMode); err != nil {
return nil, err
}
return newSyncFileWriter(conn, mtime), nil
}
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {
mode, err := s.ReadFileMode()
if err != nil {

76
sync_file_writer.go Normal file
View file

@ -0,0 +1,76 @@
package goadb
import (
"fmt"
"io"
"os"
"time"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
// syncFileWriter wraps a SyncConn that has requested to send a file.
type syncFileWriter struct {
// The modification time to write in the footer.
// If 0, use the current time.
mtime time.Time
// Reader used to read data from the adb connection.
sender wire.SyncSender
}
var _ io.WriteCloser = &syncFileWriter{}
func newSyncFileWriter(s wire.SyncSender, mtime time.Time) io.WriteCloser {
return &syncFileWriter{
mtime: mtime,
sender: s,
}
}
/*
encodePathAndMode encodes a path and file mode as required for starting a send file stream.
From https://android.googlesource.com/platform/system/core/+/master/adb/SYNC.TXT:
The remote file name is split into two parts separated by the last
comma (","). The first part is the actual path, while the second is a decimal
encoded file mode containing the permissions of the file on device.
*/
func encodePathAndMode(path string, mode os.FileMode) []byte {
return []byte(fmt.Sprintf("%s,%d", path, uint32(mode.Perm())))
}
// Write writes the min of (len(buf), 64k).
func (w *syncFileWriter) Write(buf []byte) (n int, err error) {
// Writes < 64k have a one-to-one mapping to chunks.
// If buffer is larger than the max, we'll return the max size and leave it up to the
// caller to handle correctly.
if len(buf) > wire.SyncMaxChunkSize {
buf = buf[:wire.SyncMaxChunkSize]
}
if err := w.sender.SendOctetString(wire.StatusSyncData); err != nil {
return 0, err
}
if err := w.sender.SendBytes(buf); err != nil {
return 0, err
}
return len(buf), nil
}
func (w *syncFileWriter) Close() error {
if w.mtime.IsZero() {
w.mtime = time.Now()
}
if err := w.sender.SendOctetString(wire.StatusSyncDone); err != nil {
return util.WrapErrf(err, "error closing file writer")
}
if err := w.sender.SendTime(w.mtime); err != nil {
return util.WrapErrf(err, "error writing file modification time")
}
return util.WrapErrf(w.sender.Close(), "error closing FileWriter")
}

93
sync_file_writer_test.go Normal file
View file

@ -0,0 +1,93 @@
package goadb
import (
"bytes"
"testing"
"time"
"encoding/binary"
"strings"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/wire"
)
func TestFileWriterWriteSingleChunk(t *testing.T) {
var buf bytes.Buffer
writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose)
n, err := writer.Write([]byte("hello"))
assert.NoError(t, err)
assert.Equal(t, 5, n)
assert.Equal(t, "DATA\005\000\000\000hello", buf.String())
}
func TestFileWriterWriteMultiChunk(t *testing.T) {
var buf bytes.Buffer
writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose)
n, err := writer.Write([]byte("hello"))
assert.NoError(t, err)
assert.Equal(t, 5, n)
n, err = writer.Write([]byte(" world"))
assert.NoError(t, err)
assert.Equal(t, 6, n)
assert.Equal(t, "DATA\005\000\000\000helloDATA\006\000\000\000 world", buf.String())
}
func TestFileWriterWriteLargeChunk(t *testing.T) {
var buf bytes.Buffer
writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose)
data := make([]byte, wire.SyncMaxChunkSize+1)
n, err := writer.Write(data)
assert.NoError(t, err)
assert.Equal(t, wire.SyncMaxChunkSize, n)
assert.Equal(t, 8 + wire.SyncMaxChunkSize, buf.Len())
expectedHeader := []byte("DATA0000")
binary.LittleEndian.PutUint32(expectedHeader[4:], wire.SyncMaxChunkSize)
assert.Equal(t, expectedHeader, buf.Bytes()[:8])
assert.Equal(t, string(data[:wire.SyncMaxChunkSize]), buf.String()[8:])
}
func TestFileWriterCloseEmpty(t *testing.T) {
var buf bytes.Buffer
mtime := time.Unix(1, 0)
writer := newSyncFileWriter(wire.NewSyncSender(&buf), mtime)
assert.NoError(t, writer.Close())
assert.Equal(t, "DONE\x01\x00\x00\x00", buf.String())
}
func TestFileWriterWriteClose(t *testing.T) {
var buf bytes.Buffer
mtime := time.Unix(1, 0)
writer := newSyncFileWriter(wire.NewSyncSender(&buf), mtime)
writer.Write([]byte("hello"))
assert.NoError(t, writer.Close())
assert.Equal(t, "DATA\005\000\000\000helloDONE\x01\x00\x00\x00", buf.String())
}
func TestFileWriterCloseAutoMtime(t *testing.T) {
var buf bytes.Buffer
writer := newSyncFileWriter(wire.NewSyncSender(&buf), MtimeOfClose)
assert.NoError(t, writer.Close())
assert.Len(t, buf.String(), 8)
assert.True(t, strings.HasPrefix(buf.String(), "DONE"))
mtimeBytes := buf.Bytes()[4:]
mtimeActual := time.Unix(int64(binary.LittleEndian.Uint32(mtimeBytes)), 0)
// Delta has to be a whole second since adb only supports second granularity for mtimes.
assert.WithinDuration(t, time.Now(), mtimeActual, 1*time.Second)
}

View file

@ -77,6 +77,44 @@ func WrapErrf(cause error, format string, args ...interface{}) error {
}
}
// CombineErrs returns an error that wraps all the non-nil errors passed to it.
// If all errors are nil, returns nil.
// If there's only one non-nil error, returns that error without wrapping.
// Else, returns an error with the message and code as passed, with the cause set to an error
// that contains all the non-nil errors and for which Error() returns the concatenation of all their messages.
func CombineErrs(msg string, code ErrCode, errs ...error) error {
var nonNilErrs []error
for _, err := range errs {
if err != nil {
nonNilErrs = append(nonNilErrs, err)
}
}
switch len(nonNilErrs) {
case 0:
return nil
case 1:
return nonNilErrs[0]
default:
return WrapErrorf(multiError(nonNilErrs), code, "%s", msg)
}
}
type multiError []error
func (errs multiError) Error() string {
var buf bytes.Buffer
fmt.Fprintf(&buf, "%d errors: [", len(errs))
for i, err := range errs {
buf.WriteString(err.Error())
if i < len(errs)-1 {
buf.WriteString(" ")
}
}
buf.WriteRune(']')
return buf.String()
}
/*
WrapErrorf returns an *Err that wraps another arbitrary error with an ErrCode and a message.

View file

@ -24,3 +24,19 @@ caused by err3`
assert.Equal(t, expected, ErrorWithCauseChain(err))
}
func TestCombineErrors(t *testing.T) {
assert.NoError(t, CombineErrs("hello", AdbError))
assert.NoError(t, CombineErrs("hello", AdbError, nil, nil))
err1 := errors.New("lulz")
err2 := errors.New("fail")
err := CombineErrs("hello", AdbError, nil, err1, nil)
assert.EqualError(t, err, "lulz")
err = CombineErrs("hello", AdbError, err1, err2)
assert.EqualError(t, err, "AdbError: hello")
assert.Equal(t, `AdbError: hello
caused by 2 errors: [lulz fail]`, ErrorWithCauseChain(err))
}

View file

@ -1,9 +1,10 @@
// TODO(z): Write SyncSender.SendBytes().
package wire
import "github.com/zach-klippenstein/goadb/util"
const (
// Chunks cannot be longer than 64k.
MaxChunkSize = 64 * 1024
SyncMaxChunkSize = 64 * 1024
)
/*
@ -28,3 +29,9 @@ type SyncConn struct {
SyncScanner
SyncSender
}
// Close closes both the sender and the scanner, and returns any errors.
func (c SyncConn) Close() error {
return util.CombineErrs("error closing SyncConn", util.NetworkError,
c.SyncScanner.Close(), c.SyncSender.Close())
}

View file

@ -10,14 +10,17 @@ import (
)
type SyncSender interface {
io.Closer
// SendOctetString sends a 4-byte string.
SendOctetString(string) error
SendInt32(int32) error
SendFileMode(os.FileMode) error
SendTime(time.Time) error
// Sends len(bytes) as an octet, followed by bytes.
SendString(str string) error
// Sends len(data) as an octet, followed by the bytes.
// If data is bigger than SyncMaxChunkSize, it returns an assertion error.
SendBytes(data []byte) error
}
type realSyncSender struct {
@ -54,17 +57,24 @@ func (s *realSyncSender) SendTime(t time.Time) error {
util.NetworkError, "error sending time on sync sender")
}
func (s *realSyncSender) SendString(str string) error {
length := len(str)
if length > MaxChunkSize {
func (s *realSyncSender) SendBytes(data []byte) error {
length := len(data)
if length > SyncMaxChunkSize {
// This limit might not apply to filenames, but it's big enough
// that I don't think it will be a problem.
return util.AssertionErrorf("str must be <= %d in length", MaxChunkSize)
return util.AssertionErrorf("data must be <= %d in length", SyncMaxChunkSize)
}
if err := s.SendInt32(int32(length)); err != nil {
return util.WrapErrorf(err, util.NetworkError, "error sending string length on sync sender")
return util.WrapErrorf(err, util.NetworkError, "error sending data length on sync sender")
}
return util.WrapErrorf(writeFully(s.Writer, []byte(str)),
util.NetworkError, "error sending string on sync sender")
return util.WrapErrorf(writeFully(s.Writer, data),
util.NetworkError, "error sending data on sync sender")
}
func (s *realSyncSender) Close() error {
if closer, ok := s.Writer.(io.Closer); ok {
return util.WrapErrorf(closer.Close(), util.NetworkError, "error closing sync sender")
}
return nil
}

View file

@ -60,10 +60,10 @@ func TestSyncReadStringTooShort(t *testing.T) {
assert.Equal(t, errIncompleteMessage("bytes", 1, 5), err)
}
func TestSyncSendString(t *testing.T) {
func TestSyncSendBytes(t *testing.T) {
var buf bytes.Buffer
s := NewSyncSender(&buf)
err := s.SendString("hello")
err := s.SendBytes([]byte("hello"))
assert.NoError(t, err)
assert.Equal(t, "\005\000\000\000hello", buf.String())
}