Merge pull request #7 from zach-klippenstein/fix-openread-errors
OpenRead will now return an error if the server returns an error when trying to open a file.
This commit is contained in:
commit
d32cf6b500
|
@ -131,7 +131,7 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
|
|||
|
||||
remoteFile, err := client.OpenRead(remotePath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err)
|
||||
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, util.ErrorWithCauseChain(err))
|
||||
return 1
|
||||
}
|
||||
defer remoteFile.Close()
|
||||
|
|
|
@ -59,7 +59,7 @@ func doCommand(cmd string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
status, err := conn.ReadStatus()
|
||||
status, err := conn.ReadStatus("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -112,7 +112,7 @@ func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
|
|||
if err = conn.SendMessage([]byte(req)); err != nil {
|
||||
return "", wrapClientError(err, c, "RunCommand")
|
||||
}
|
||||
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
|
||||
if _, err = conn.ReadStatus(req); err != nil {
|
||||
return "", wrapClientError(err, c, "RunCommand")
|
||||
}
|
||||
|
||||
|
@ -192,7 +192,7 @@ func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
|
|||
if err := wire.SendMessageString(conn, "sync:"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil {
|
||||
if _, err := conn.ReadStatus("sync"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -213,7 +213,7 @@ func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
|
|||
return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor)
|
||||
}
|
||||
|
||||
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
|
||||
if _, err = conn.ReadStatus(req); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -180,7 +180,7 @@ func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err := wire.ReadStatusFailureAsError(conn, "host:track-devices"); err != nil {
|
||||
if _, err := conn.ReadStatus("host:track-devices"); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -74,16 +74,16 @@ func (entries *DirEntries) Close() error {
|
|||
}
|
||||
|
||||
func readNextDirListEntry(s wire.SyncScanner) (entry *DirEntry, done bool, err error) {
|
||||
id, err := s.ReadOctetString()
|
||||
status, err := s.ReadStatus("dir-entry")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if id == "DONE" {
|
||||
if status == "DONE" {
|
||||
done = true
|
||||
return
|
||||
} else if id != "DENT" {
|
||||
err = fmt.Errorf("error reading dir entries: expected dir entry ID 'DENT', but got '%s'", id)
|
||||
} else if status != "DENT" {
|
||||
err = fmt.Errorf("error reading dir entries: expected dir entry ID 'DENT', but got '%s'", status)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ type MockServer struct {
|
|||
// but not returned.
|
||||
Errs []error
|
||||
|
||||
Status wire.StatusCode
|
||||
Status string
|
||||
|
||||
// Messages are returned from read calls in order, each preceded by a length header.
|
||||
Messages []string
|
||||
|
@ -53,7 +53,7 @@ func (s *MockServer) Dial() (*wire.Conn, error) {
|
|||
return wire.NewConn(s, s), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadStatus() (wire.StatusCode, error) {
|
||||
func (s *MockServer) ReadStatus(req string) (string, error) {
|
||||
s.logMethod("ReadStatus")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return "", err
|
||||
|
|
|
@ -20,7 +20,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
id, err := conn.ReadOctetString()
|
||||
id, err := conn.ReadStatus("stat")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -49,8 +49,7 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) {
|
|||
if err := conn.SendString(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newSyncFileReader(conn), nil
|
||||
return newSyncFileReader(conn)
|
||||
}
|
||||
|
||||
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
"github.com/zach-klippenstein/goadb/wire"
|
||||
)
|
||||
|
||||
|
@ -18,10 +19,17 @@ type syncFileReader struct {
|
|||
|
||||
var _ io.ReadCloser = &syncFileReader{}
|
||||
|
||||
func newSyncFileReader(s wire.SyncScanner) io.ReadCloser {
|
||||
return &syncFileReader{
|
||||
func newSyncFileReader(s wire.SyncScanner) (r io.ReadCloser, err error) {
|
||||
r = &syncFileReader{
|
||||
scanner: s,
|
||||
}
|
||||
|
||||
// Read the header for the first chunk to consume any errors.
|
||||
if _, err = r.Read([]byte{}); err != nil {
|
||||
r.Close()
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *syncFileReader) Read(buf []byte) (n int, err error) {
|
||||
|
@ -35,6 +43,13 @@ func (r *syncFileReader) Read(buf []byte) (n int, err error) {
|
|||
r.chunkReader = chunkReader
|
||||
}
|
||||
|
||||
if len(buf) == 0 {
|
||||
// Read can be called with an empty buffer to read the next chunk and check for errors.
|
||||
// However, net.Conn.Read seems to return EOF when given an empty buffer, so we need to
|
||||
// handle that case ourselves.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, err = r.chunkReader.Read(buf)
|
||||
if err == io.EOF {
|
||||
// End of current chunk, don't return an error, the next chunk will be
|
||||
|
@ -53,17 +68,27 @@ func (r *syncFileReader) Close() error {
|
|||
// readNextChunk creates an io.LimitedReader for the next chunk of data,
|
||||
// and returns io.EOF if the last chunk has been read.
|
||||
func readNextChunk(r wire.SyncScanner) (io.Reader, error) {
|
||||
id, err := r.ReadOctetString()
|
||||
status, err := r.ReadStatus("read-chunk")
|
||||
if err != nil {
|
||||
if wire.IsAdbServerErrorMatching(err, readFileNotFoundPredicate) {
|
||||
return nil, util.Errorf(util.FileNoExistError, "no such file or directory")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch id {
|
||||
case "DATA":
|
||||
switch status {
|
||||
case wire.StatusSyncData:
|
||||
return r.ReadBytes()
|
||||
case "DONE":
|
||||
case wire.StatusSyncDone:
|
||||
return nil, io.EOF
|
||||
default:
|
||||
return nil, fmt.Errorf("expected chunk id 'DATA', but got '%s'", id)
|
||||
return nil, fmt.Errorf("expected chunk id '%s' or '%s', but got '%s'",
|
||||
wire.StatusSyncData, wire.StatusSyncDone, []byte(status))
|
||||
}
|
||||
}
|
||||
|
||||
// readFileNotFoundPredicate returns true if s is the adb server error message returned
|
||||
// when trying to open a file that doesn't exist.
|
||||
func readFileNotFoundPredicate(s string) bool {
|
||||
return s == "No such file or directory"
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
"github.com/zach-klippenstein/goadb/wire"
|
||||
)
|
||||
|
||||
|
@ -43,16 +44,17 @@ func TestReadNextChunkInvalidChunkId(t *testing.T) {
|
|||
|
||||
// Read 1st chunk
|
||||
_, err := readNextChunk(s)
|
||||
assert.EqualError(t, err, "expected chunk id 'DATA', but got 'ATAD'")
|
||||
assert.EqualError(t, err, "expected chunk id 'DATA' or 'DONE', but got 'ATAD'")
|
||||
}
|
||||
|
||||
func TestReadMultipleCalls(t *testing.T) {
|
||||
s := wire.NewSyncScanner(strings.NewReader(
|
||||
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
|
||||
reader := newSyncFileReader(s)
|
||||
reader, err := newSyncFileReader(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
firstByte := make([]byte, 1)
|
||||
_, err := io.ReadFull(reader, firstByte)
|
||||
_, err = io.ReadFull(reader, firstByte)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "h", string(firstByte))
|
||||
|
||||
|
@ -73,10 +75,26 @@ func TestReadMultipleCalls(t *testing.T) {
|
|||
func TestReadAll(t *testing.T) {
|
||||
s := wire.NewSyncScanner(strings.NewReader(
|
||||
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
|
||||
reader := newSyncFileReader(s)
|
||||
reader, err := newSyncFileReader(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 20)
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
_, err = io.ReadFull(reader, buf)
|
||||
assert.Equal(t, io.ErrUnexpectedEOF, err)
|
||||
assert.Equal(t, "hello world\000", string(buf[:12]))
|
||||
}
|
||||
|
||||
func TestReadError(t *testing.T) {
|
||||
s := wire.NewSyncScanner(strings.NewReader(
|
||||
"FAIL\004\000\000\000fail"))
|
||||
_, err := newSyncFileReader(s)
|
||||
assert.EqualError(t, err, "AdbError: server error for read-chunk request: fail ({Request:read-chunk ServerMsg:fail})")
|
||||
}
|
||||
|
||||
func TestReadErrorNotFound(t *testing.T) {
|
||||
s := wire.NewSyncScanner(strings.NewReader(
|
||||
"FAIL\031\000\000\000No such file or directory"))
|
||||
_, err := newSyncFileReader(s)
|
||||
assert.True(t, util.HasErrCode(err, util.FileNoExistError))
|
||||
assert.EqualError(t, err, "FileNoExistError: no such file or directory")
|
||||
}
|
||||
|
|
|
@ -55,7 +55,7 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if err = ReadStatusFailureAsError(conn, string(req)); err != nil {
|
||||
if _, err = conn.ReadStatus(string(req)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
131
wire/scanner.go
131
wire/scanner.go
|
@ -1,6 +1,7 @@
|
|||
package wire
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strconv"
|
||||
|
@ -12,16 +13,23 @@ import (
|
|||
|
||||
// StatusCodes are returned by the server. If the code indicates failure, the
|
||||
// next message will be the error.
|
||||
type StatusCode string
|
||||
|
||||
const (
|
||||
StatusSuccess StatusCode = "OKAY"
|
||||
StatusSuccess string = "OKAY"
|
||||
StatusFailure = "FAIL"
|
||||
StatusSyncData = "DATA"
|
||||
StatusSyncDone = "DONE"
|
||||
StatusNone = ""
|
||||
)
|
||||
|
||||
func (status StatusCode) IsSuccess() bool {
|
||||
return status == StatusSuccess
|
||||
func isFailureStatus(status string) bool {
|
||||
return status == StatusFailure
|
||||
}
|
||||
|
||||
type StatusReader interface {
|
||||
// Reads a 4-byte status string and returns it.
|
||||
// If the status string is StatusFailure, reads the error message from the server
|
||||
// and returns it as an util.AdbError.
|
||||
ReadStatus(req string) (string, error)
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -29,13 +37,12 @@ Scanner reads tokens from a server.
|
|||
See Conn for more details.
|
||||
*/
|
||||
type Scanner interface {
|
||||
ReadStatus() (StatusCode, error)
|
||||
io.Closer
|
||||
StatusReader
|
||||
ReadMessage() ([]byte, error)
|
||||
ReadUntilEof() ([]byte, error)
|
||||
|
||||
NewSyncScanner() SyncScanner
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
type realScanner struct {
|
||||
|
@ -54,36 +61,12 @@ func ReadMessageString(s Scanner) (string, error) {
|
|||
return string(msg), nil
|
||||
}
|
||||
|
||||
func (s *realScanner) ReadStatus() (StatusCode, error) {
|
||||
status := make([]byte, 4)
|
||||
n, err := io.ReadFull(s.reader, status)
|
||||
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
return "", util.WrapErrorf(err, util.NetworkError, "error reading status")
|
||||
} else if err == io.ErrUnexpectedEOF {
|
||||
return StatusCode(status), errIncompleteMessage("status", n, 4)
|
||||
}
|
||||
|
||||
return StatusCode(status), nil
|
||||
func (s *realScanner) ReadStatus(req string) (string, error) {
|
||||
return readStatusFailureAsError(s.reader, req, readHexLength)
|
||||
}
|
||||
|
||||
func (s *realScanner) ReadMessage() ([]byte, error) {
|
||||
var err error
|
||||
|
||||
length, err := s.readLength()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err := io.ReadFull(s.reader, data)
|
||||
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
return data, util.WrapErrorf(err, util.NetworkError, "error reading message data")
|
||||
} else if err == io.ErrUnexpectedEOF {
|
||||
return data, errIncompleteMessage("message data", n, length)
|
||||
}
|
||||
return data, nil
|
||||
return readMessage(s.reader, readHexLength)
|
||||
}
|
||||
|
||||
func (s *realScanner) ReadUntilEof() ([]byte, error) {
|
||||
|
@ -102,9 +85,75 @@ func (s *realScanner) Close() error {
|
|||
return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner")
|
||||
}
|
||||
|
||||
func (s *realScanner) readLength() (int, error) {
|
||||
var _ Scanner = &realScanner{}
|
||||
|
||||
// lengthReader is a func that readMessage uses to read message length.
|
||||
// See readHexLength and readInt32.
|
||||
type lengthReader func(io.Reader) (int, error)
|
||||
|
||||
// Reads the status, and if failure, reads the message and returns it as an error.
|
||||
// If the status is success, doesn't read the message.
|
||||
// req is just used to populate the AdbError, and can be nil.
|
||||
// messageLengthReader is the function passed to readMessage if the status is failure.
|
||||
func readStatusFailureAsError(r io.Reader, req string, messageLengthReader lengthReader) (string, error) {
|
||||
status, err := readOctetString(req, r)
|
||||
if err != nil {
|
||||
return "", util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req)
|
||||
}
|
||||
|
||||
if isFailureStatus(status) {
|
||||
msg, err := readMessage(r, messageLengthReader)
|
||||
if err != nil {
|
||||
return "", util.WrapErrorf(err, util.NetworkError,
|
||||
"server returned error for %s, but couldn't read the error message", req)
|
||||
}
|
||||
|
||||
return "", adbServerError(req, string(msg))
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func readOctetString(description string, r io.Reader) (string, error) {
|
||||
octet := make([]byte, 4)
|
||||
n, err := io.ReadFull(r, octet)
|
||||
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return "", errIncompleteMessage(description, n, 4)
|
||||
} else if err != nil {
|
||||
return "", util.WrapErrorf(err, util.NetworkError, "error reading "+description)
|
||||
}
|
||||
|
||||
return string(octet), nil
|
||||
}
|
||||
|
||||
// readMessage reads a length from r, then reads length bytes and returns them.
|
||||
// lengthReader is the function used to read the length. Most operations encode
|
||||
// length as a hex string (readHexLength), but sync operations use little-endian
|
||||
// binary encoding (readInt32).
|
||||
func readMessage(r io.Reader, lengthReader lengthReader) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
length, err := lengthReader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := make([]byte, length)
|
||||
n, err := io.ReadFull(r, data)
|
||||
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
return data, util.WrapErrorf(err, util.NetworkError, "error reading message data")
|
||||
} else if err == io.ErrUnexpectedEOF {
|
||||
return data, errIncompleteMessage("message data", n, length)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// readHexLength reads the next 4 bytes from r as an ASCII hex-encoded length and parses them into an int.
|
||||
func readHexLength(r io.Reader) (int, error) {
|
||||
lengthHex := make([]byte, 4)
|
||||
n, err := io.ReadFull(s.reader, lengthHex)
|
||||
n, err := io.ReadFull(r, lengthHex)
|
||||
if err != nil {
|
||||
return 0, errIncompleteMessage("length", n, 4)
|
||||
}
|
||||
|
@ -122,4 +171,10 @@ func (s *realScanner) readLength() (int, error) {
|
|||
return int(length), nil
|
||||
}
|
||||
|
||||
var _ Scanner = &realScanner{}
|
||||
// readInt32 reads the next 4 bytes from r as a little-endian integer.
|
||||
// Returns an int instead of an int32 to match the lengthReader type.
|
||||
func readInt32(r io.Reader) (int, error) {
|
||||
var value int32
|
||||
err := binary.Read(r, binary.LittleEndian, &value)
|
||||
return int(value), err
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -11,109 +12,120 @@ import (
|
|||
)
|
||||
|
||||
func TestReadStatusOkay(t *testing.T) {
|
||||
s := NewScannerString("OKAYd")
|
||||
status, err := s.ReadStatus()
|
||||
s := newEofReader("OKAYd")
|
||||
status, err := readStatusFailureAsError(s, "", readHexLength)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, status.IsSuccess())
|
||||
assert.False(t, isFailureStatus(status))
|
||||
assertNotEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadIncompleteStatus(t *testing.T) {
|
||||
s := NewScannerString("oka")
|
||||
_, err := s.ReadStatus()
|
||||
assert.Equal(t, errIncompleteMessage("status", 3, 4), err)
|
||||
s := newEofReader("oka")
|
||||
_, err := readStatusFailureAsError(s, "", readHexLength)
|
||||
assert.EqualError(t, err, "NetworkError: error reading status for ")
|
||||
assert.Equal(t, errIncompleteMessage("", 3, 4), err.(*util.Err).Cause)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadFailureIncompleteStatus(t *testing.T) {
|
||||
s := newEofReader("FAIL")
|
||||
_, err := readStatusFailureAsError(s, "req", readHexLength)
|
||||
assert.EqualError(t, err, "NetworkError: server returned error for req, but couldn't read the error message")
|
||||
assert.Error(t, err.(*util.Err).Cause)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadFailureEmptyStatus(t *testing.T) {
|
||||
s := newEofReader("FAIL0000")
|
||||
_, err := readStatusFailureAsError(s, "", readHexLength)
|
||||
assert.EqualError(t, err, "AdbError: server error: ({Request: ServerMsg:})")
|
||||
assert.NoError(t, err.(*util.Err).Cause)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadFailureStatus(t *testing.T) {
|
||||
s := newEofReader("FAIL0004fail")
|
||||
_, err := readStatusFailureAsError(s, "", readHexLength)
|
||||
assert.EqualError(t, err, "AdbError: server error: fail ({Request: ServerMsg:fail})")
|
||||
assert.NoError(t, err.(*util.Err).Cause)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadMessage(t *testing.T) {
|
||||
s := newEofReader("0005hello")
|
||||
msg, err := readMessage(s, readHexLength)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msg, 5)
|
||||
assert.Equal(t, "hello", string(msg))
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadMessageWithExtraData(t *testing.T) {
|
||||
s := newEofReader("0005hellothere")
|
||||
msg, err := readMessage(s, readHexLength)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msg, 5)
|
||||
assert.Equal(t, "hello", string(msg))
|
||||
assertNotEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadLongerMessage(t *testing.T) {
|
||||
s := newEofReader("001b192.168.56.101:5555 device\n")
|
||||
msg, err := readMessage(s, readHexLength)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msg, 27)
|
||||
assert.Equal(t, "192.168.56.101:5555 device\n", string(msg))
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadEmptyMessage(t *testing.T) {
|
||||
s := newEofReader("0000")
|
||||
msg, err := readMessage(s, readHexLength)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", string(msg))
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadIncompleteMessage(t *testing.T) {
|
||||
s := newEofReader("0005hel")
|
||||
msg, err := readMessage(s, readHexLength)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, errIncompleteMessage("message data", 3, 5), err)
|
||||
assert.Equal(t, "hel\000\000", string(msg))
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadLength(t *testing.T) {
|
||||
s := NewScannerString("000a")
|
||||
l, err := s.readLength()
|
||||
s := newEofReader("000a")
|
||||
l, err := readHexLength(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, l)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadIncompleteLength(t *testing.T) {
|
||||
s := NewScannerString("aaa")
|
||||
_, err := s.readLength()
|
||||
func TestReadLengthIncompleteLength(t *testing.T) {
|
||||
s := newEofReader("aaa")
|
||||
_, err := readHexLength(s)
|
||||
assert.Equal(t, errIncompleteMessage("length", 3, 4), err)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadMessage(t *testing.T) {
|
||||
s := NewScannerString("0005hello")
|
||||
msg, err := ReadMessageString(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msg, 5)
|
||||
assert.Equal(t, "hello", msg)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadMessageWithExtraData(t *testing.T) {
|
||||
s := NewScannerString("0005hellothere")
|
||||
msg, err := ReadMessageString(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msg, 5)
|
||||
assert.Equal(t, "hello", msg)
|
||||
assertNotEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadLongerMessage(t *testing.T) {
|
||||
s := NewScannerString("001b192.168.56.101:5555 device\n")
|
||||
msg, err := ReadMessageString(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, msg, 27)
|
||||
assert.Equal(t, "192.168.56.101:5555 device\n", msg)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadEmptyMessage(t *testing.T) {
|
||||
s := NewScannerString("0000")
|
||||
msg, err := ReadMessageString(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", msg)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func TestReadIncompleteMessage(t *testing.T) {
|
||||
s := NewScannerString("0005hel")
|
||||
msg, err := ReadMessageString(s)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, errIncompleteMessage("message data", 3, 5), err)
|
||||
assert.Equal(t, "hel\000\000", msg)
|
||||
assertEof(t, s)
|
||||
}
|
||||
|
||||
func NewScannerString(str string) *realScanner {
|
||||
return NewScanner(NewEofBuffer(str)).(*realScanner)
|
||||
}
|
||||
|
||||
// NewEofBuffer returns a bytes.Buffer of str that returns an EOF error
|
||||
// at the end of input, instead of just returning 0 bytes read.
|
||||
func NewEofBuffer(str string) *TestReader {
|
||||
limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str)))
|
||||
bufReader := bufio.NewReader(limitReader)
|
||||
return &TestReader{bufReader}
|
||||
}
|
||||
|
||||
func assertEof(t *testing.T, s *realScanner) {
|
||||
msg, err := s.ReadMessage()
|
||||
func assertEof(t *testing.T, r io.Reader) {
|
||||
msg, err := readMessage(r, readHexLength)
|
||||
assert.True(t, util.HasErrCode(err, util.ConnectionResetError))
|
||||
assert.Nil(t, msg)
|
||||
}
|
||||
|
||||
func assertNotEof(t *testing.T, s *realScanner) {
|
||||
n, err := s.reader.Read(make([]byte, 1))
|
||||
func assertNotEof(t *testing.T, r io.Reader) {
|
||||
n, err := r.Read(make([]byte, 1))
|
||||
assert.Equal(t, 1, n)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestReader is a wrapper around a bufio.Reader that implements io.Closer.
|
||||
type TestReader struct {
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
func (b *TestReader) Close() error {
|
||||
// No-op.
|
||||
return nil
|
||||
// newEofBuffer returns a bytes.Buffer of str that returns an EOF error
|
||||
// at the end of input, instead of just returning 0 bytes read.
|
||||
func newEofReader(str string) io.ReadCloser {
|
||||
limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str)))
|
||||
bufReader := bufio.NewReader(limitReader)
|
||||
return ioutil.NopCloser(bufReader)
|
||||
}
|
||||
|
|
|
@ -10,8 +10,8 @@ import (
|
|||
)
|
||||
|
||||
type SyncScanner interface {
|
||||
// ReadOctetString reads a 4-byte string.
|
||||
ReadOctetString() (string, error)
|
||||
io.Closer
|
||||
StatusReader
|
||||
ReadInt32() (int32, error)
|
||||
ReadFileMode() (os.FileMode, error)
|
||||
ReadTime() (time.Time, error)
|
||||
|
@ -23,9 +23,6 @@ type SyncScanner interface {
|
|||
// bytes (see io.LimitReader). The returned reader should be fully
|
||||
// read before reading anything off the Scanner again.
|
||||
ReadBytes() (io.Reader, error)
|
||||
|
||||
// Closes the underlying reader.
|
||||
Close() error
|
||||
}
|
||||
|
||||
type realSyncScanner struct {
|
||||
|
@ -36,33 +33,13 @@ func NewSyncScanner(r io.Reader) SyncScanner {
|
|||
return &realSyncScanner{r}
|
||||
}
|
||||
|
||||
func RequireOctetString(s SyncScanner, expected string) error {
|
||||
actual, err := s.ReadOctetString()
|
||||
if err != nil {
|
||||
return util.WrapErrorf(err, util.NetworkError, "expected to read '%s'", expected)
|
||||
}
|
||||
if actual != expected {
|
||||
return util.AssertionErrorf("expected to read '%s', got '%s'", expected, actual)
|
||||
}
|
||||
return nil
|
||||
func (s *realSyncScanner) ReadStatus(req string) (string, error) {
|
||||
return readStatusFailureAsError(s.Reader, req, readInt32)
|
||||
}
|
||||
|
||||
func (s *realSyncScanner) ReadOctetString() (string, error) {
|
||||
octet := make([]byte, 4)
|
||||
n, err := io.ReadFull(s.Reader, octet)
|
||||
|
||||
if err != nil && err != io.ErrUnexpectedEOF {
|
||||
return "", util.WrapErrorf(err, util.NetworkError, "error reading octet string from sync scanner")
|
||||
} else if err == io.ErrUnexpectedEOF {
|
||||
return "", errIncompleteMessage("octet", n, 4)
|
||||
}
|
||||
|
||||
return string(octet), nil
|
||||
}
|
||||
func (s *realSyncScanner) ReadInt32() (int32, error) {
|
||||
var value int32
|
||||
err := binary.Read(s.Reader, binary.LittleEndian, &value)
|
||||
return value, util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner")
|
||||
value, err := readInt32(s.Reader)
|
||||
return int32(value), util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner")
|
||||
}
|
||||
func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) {
|
||||
var value uint32
|
||||
|
|
|
@ -17,13 +17,6 @@ var (
|
|||
someTimeEncoded = []byte{151, 208, 42, 85}
|
||||
)
|
||||
|
||||
func TestSyncReadOctetString(t *testing.T) {
|
||||
s := NewSyncScanner(strings.NewReader("helo"))
|
||||
str, err := s.ReadOctetString()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "helo", str)
|
||||
}
|
||||
|
||||
func TestSyncSendOctetString(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
s := NewSyncSender(&buf)
|
||||
|
|
31
wire/util.go
31
wire/util.go
|
@ -21,28 +21,6 @@ type ErrorResponseDetails struct {
|
|||
// Old servers send "device not found", and newer ones "device 'serial' not found".
|
||||
var deviceNotFoundMessagePattern = regexp.MustCompile(`device( '.*')? not found`)
|
||||
|
||||
// Reads the status, and if failure, reads the message and returns it as an error.
|
||||
// If the status is success, doesn't read the message.
|
||||
// req is just used to populate the AdbError, and can be nil.
|
||||
func ReadStatusFailureAsError(s Scanner, req string) error {
|
||||
status, err := s.ReadStatus()
|
||||
if err != nil {
|
||||
return util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req)
|
||||
}
|
||||
|
||||
if !status.IsSuccess() {
|
||||
msg, err := s.ReadMessage()
|
||||
if err != nil {
|
||||
return util.WrapErrorf(err, util.NetworkError,
|
||||
"server returned error for %s, but couldn't read the error message", req)
|
||||
}
|
||||
|
||||
return adbServerError(req, string(msg))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func adbServerError(request string, serverMsg string) error {
|
||||
var msg string
|
||||
if request == "" {
|
||||
|
@ -66,6 +44,15 @@ func adbServerError(request string, serverMsg string) error {
|
|||
}
|
||||
}
|
||||
|
||||
// IsAdbServerErrorMatching returns true if err is an *util.Err with code AdbError and for which
|
||||
// predicate returns true when passed Details.ServerMsg.
|
||||
func IsAdbServerErrorMatching(err error, predicate func(string) bool) bool {
|
||||
if err, ok := err.(*util.Err); ok && err.Code == util.AdbError {
|
||||
return predicate(err.Details.(ErrorResponseDetails).ServerMsg)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func errIncompleteMessage(description string, actual int, expected int) error {
|
||||
return &util.Err{
|
||||
Code: util.ConnectionResetError,
|
||||
|
|
Loading…
Reference in a new issue