Merge pull request from zach-klippenstein/fix-openread-errors

OpenRead will now return an error if the server returns an error when trying to open a file.
This commit is contained in:
Zach Klippenstein 2015-12-28 19:57:39 -08:00
commit d32cf6b500
15 changed files with 271 additions and 205 deletions

View file

@ -131,7 +131,7 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
remoteFile, err := client.OpenRead(remotePath) remoteFile, err := client.OpenRead(remotePath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err) fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, util.ErrorWithCauseChain(err))
return 1 return 1
} }
defer remoteFile.Close() defer remoteFile.Close()

View file

@ -59,7 +59,7 @@ func doCommand(cmd string) error {
return err return err
} }
status, err := conn.ReadStatus() status, err := conn.ReadStatus("")
if err != nil { if err != nil {
return err return err
} }

View file

@ -112,7 +112,7 @@ func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
if err = conn.SendMessage([]byte(req)); err != nil { if err = conn.SendMessage([]byte(req)); err != nil {
return "", wrapClientError(err, c, "RunCommand") return "", wrapClientError(err, c, "RunCommand")
} }
if err = wire.ReadStatusFailureAsError(conn, req); err != nil { if _, err = conn.ReadStatus(req); err != nil {
return "", wrapClientError(err, c, "RunCommand") return "", wrapClientError(err, c, "RunCommand")
} }
@ -192,7 +192,7 @@ func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
if err := wire.SendMessageString(conn, "sync:"); err != nil { if err := wire.SendMessageString(conn, "sync:"); err != nil {
return nil, err return nil, err
} }
if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil { if _, err := conn.ReadStatus("sync"); err != nil {
return nil, err return nil, err
} }
@ -213,7 +213,7 @@ func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor) return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor)
} }
if err = wire.ReadStatusFailureAsError(conn, req); err != nil { if _, err = conn.ReadStatus(req); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }

View file

@ -180,7 +180,7 @@ func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) {
return nil, err return nil, err
} }
if err := wire.ReadStatusFailureAsError(conn, "host:track-devices"); err != nil { if _, err := conn.ReadStatus("host:track-devices"); err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
} }

View file

@ -74,16 +74,16 @@ func (entries *DirEntries) Close() error {
} }
func readNextDirListEntry(s wire.SyncScanner) (entry *DirEntry, done bool, err error) { func readNextDirListEntry(s wire.SyncScanner) (entry *DirEntry, done bool, err error) {
id, err := s.ReadOctetString() status, err := s.ReadStatus("dir-entry")
if err != nil { if err != nil {
return return
} }
if id == "DONE" { if status == "DONE" {
done = true done = true
return return
} else if id != "DENT" { } else if status != "DENT" {
err = fmt.Errorf("error reading dir entries: expected dir entry ID 'DENT', but got '%s'", id) err = fmt.Errorf("error reading dir entries: expected dir entry ID 'DENT', but got '%s'", status)
return return
} }

View file

@ -32,7 +32,7 @@ type MockServer struct {
// but not returned. // but not returned.
Errs []error Errs []error
Status wire.StatusCode Status string
// Messages are returned from read calls in order, each preceded by a length header. // Messages are returned from read calls in order, each preceded by a length header.
Messages []string Messages []string
@ -53,7 +53,7 @@ func (s *MockServer) Dial() (*wire.Conn, error) {
return wire.NewConn(s, s), nil return wire.NewConn(s, s), nil
} }
func (s *MockServer) ReadStatus() (wire.StatusCode, error) { func (s *MockServer) ReadStatus(req string) (string, error) {
s.logMethod("ReadStatus") s.logMethod("ReadStatus")
if err := s.getNextErrToReturn(); err != nil { if err := s.getNextErrToReturn(); err != nil {
return "", err return "", err

View file

@ -20,7 +20,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) {
return nil, err return nil, err
} }
id, err := conn.ReadOctetString() id, err := conn.ReadStatus("stat")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -49,8 +49,7 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) {
if err := conn.SendString(path); err != nil { if err := conn.SendString(path); err != nil {
return nil, err return nil, err
} }
return newSyncFileReader(conn)
return newSyncFileReader(conn), nil
} }
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) { func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {

View file

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire" "github.com/zach-klippenstein/goadb/wire"
) )
@ -18,10 +19,17 @@ type syncFileReader struct {
var _ io.ReadCloser = &syncFileReader{} var _ io.ReadCloser = &syncFileReader{}
func newSyncFileReader(s wire.SyncScanner) io.ReadCloser { func newSyncFileReader(s wire.SyncScanner) (r io.ReadCloser, err error) {
return &syncFileReader{ r = &syncFileReader{
scanner: s, scanner: s,
} }
// Read the header for the first chunk to consume any errors.
if _, err = r.Read([]byte{}); err != nil {
r.Close()
return nil, err
}
return
} }
func (r *syncFileReader) Read(buf []byte) (n int, err error) { func (r *syncFileReader) Read(buf []byte) (n int, err error) {
@ -35,6 +43,13 @@ func (r *syncFileReader) Read(buf []byte) (n int, err error) {
r.chunkReader = chunkReader r.chunkReader = chunkReader
} }
if len(buf) == 0 {
// Read can be called with an empty buffer to read the next chunk and check for errors.
// However, net.Conn.Read seems to return EOF when given an empty buffer, so we need to
// handle that case ourselves.
return 0, nil
}
n, err = r.chunkReader.Read(buf) n, err = r.chunkReader.Read(buf)
if err == io.EOF { if err == io.EOF {
// End of current chunk, don't return an error, the next chunk will be // End of current chunk, don't return an error, the next chunk will be
@ -53,17 +68,27 @@ func (r *syncFileReader) Close() error {
// readNextChunk creates an io.LimitedReader for the next chunk of data, // readNextChunk creates an io.LimitedReader for the next chunk of data,
// and returns io.EOF if the last chunk has been read. // and returns io.EOF if the last chunk has been read.
func readNextChunk(r wire.SyncScanner) (io.Reader, error) { func readNextChunk(r wire.SyncScanner) (io.Reader, error) {
id, err := r.ReadOctetString() status, err := r.ReadStatus("read-chunk")
if err != nil { if err != nil {
if wire.IsAdbServerErrorMatching(err, readFileNotFoundPredicate) {
return nil, util.Errorf(util.FileNoExistError, "no such file or directory")
}
return nil, err return nil, err
} }
switch id { switch status {
case "DATA": case wire.StatusSyncData:
return r.ReadBytes() return r.ReadBytes()
case "DONE": case wire.StatusSyncDone:
return nil, io.EOF return nil, io.EOF
default: default:
return nil, fmt.Errorf("expected chunk id 'DATA', but got '%s'", id) return nil, fmt.Errorf("expected chunk id '%s' or '%s', but got '%s'",
wire.StatusSyncData, wire.StatusSyncDone, []byte(status))
} }
} }
// readFileNotFoundPredicate returns true if s is the adb server error message returned
// when trying to open a file that doesn't exist.
func readFileNotFoundPredicate(s string) bool {
return s == "No such file or directory"
}

View file

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire" "github.com/zach-klippenstein/goadb/wire"
) )
@ -43,16 +44,17 @@ func TestReadNextChunkInvalidChunkId(t *testing.T) {
// Read 1st chunk // Read 1st chunk
_, err := readNextChunk(s) _, err := readNextChunk(s)
assert.EqualError(t, err, "expected chunk id 'DATA', but got 'ATAD'") assert.EqualError(t, err, "expected chunk id 'DATA' or 'DONE', but got 'ATAD'")
} }
func TestReadMultipleCalls(t *testing.T) { func TestReadMultipleCalls(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader( s := wire.NewSyncScanner(strings.NewReader(
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
reader := newSyncFileReader(s) reader, err := newSyncFileReader(s)
assert.NoError(t, err)
firstByte := make([]byte, 1) firstByte := make([]byte, 1)
_, err := io.ReadFull(reader, firstByte) _, err = io.ReadFull(reader, firstByte)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "h", string(firstByte)) assert.Equal(t, "h", string(firstByte))
@ -73,10 +75,26 @@ func TestReadMultipleCalls(t *testing.T) {
func TestReadAll(t *testing.T) { func TestReadAll(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader( s := wire.NewSyncScanner(strings.NewReader(
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
reader := newSyncFileReader(s) reader, err := newSyncFileReader(s)
assert.NoError(t, err)
buf := make([]byte, 20) buf := make([]byte, 20)
_, err := io.ReadFull(reader, buf) _, err = io.ReadFull(reader, buf)
assert.Equal(t, io.ErrUnexpectedEOF, err) assert.Equal(t, io.ErrUnexpectedEOF, err)
assert.Equal(t, "hello world\000", string(buf[:12])) assert.Equal(t, "hello world\000", string(buf[:12]))
} }
func TestReadError(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader(
"FAIL\004\000\000\000fail"))
_, err := newSyncFileReader(s)
assert.EqualError(t, err, "AdbError: server error for read-chunk request: fail ({Request:read-chunk ServerMsg:fail})")
}
func TestReadErrorNotFound(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader(
"FAIL\031\000\000\000No such file or directory"))
_, err := newSyncFileReader(s)
assert.True(t, util.HasErrCode(err, util.FileNoExistError))
assert.EqualError(t, err, "FileNoExistError: no such file or directory")
}

View file

@ -55,7 +55,7 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) {
return nil, err return nil, err
} }
if err = ReadStatusFailureAsError(conn, string(req)); err != nil { if _, err = conn.ReadStatus(string(req)); err != nil {
return nil, err return nil, err
} }

View file

@ -1,6 +1,7 @@
package wire package wire
import ( import (
"encoding/binary"
"io" "io"
"io/ioutil" "io/ioutil"
"strconv" "strconv"
@ -12,16 +13,23 @@ import (
// StatusCodes are returned by the server. If the code indicates failure, the // StatusCodes are returned by the server. If the code indicates failure, the
// next message will be the error. // next message will be the error.
type StatusCode string
const ( const (
StatusSuccess StatusCode = "OKAY" StatusSuccess string = "OKAY"
StatusFailure = "FAIL" StatusFailure = "FAIL"
StatusSyncData = "DATA"
StatusSyncDone = "DONE"
StatusNone = "" StatusNone = ""
) )
func (status StatusCode) IsSuccess() bool { func isFailureStatus(status string) bool {
return status == StatusSuccess return status == StatusFailure
}
type StatusReader interface {
// Reads a 4-byte status string and returns it.
// If the status string is StatusFailure, reads the error message from the server
// and returns it as an util.AdbError.
ReadStatus(req string) (string, error)
} }
/* /*
@ -29,13 +37,12 @@ Scanner reads tokens from a server.
See Conn for more details. See Conn for more details.
*/ */
type Scanner interface { type Scanner interface {
ReadStatus() (StatusCode, error) io.Closer
StatusReader
ReadMessage() ([]byte, error) ReadMessage() ([]byte, error)
ReadUntilEof() ([]byte, error) ReadUntilEof() ([]byte, error)
NewSyncScanner() SyncScanner NewSyncScanner() SyncScanner
Close() error
} }
type realScanner struct { type realScanner struct {
@ -54,36 +61,12 @@ func ReadMessageString(s Scanner) (string, error) {
return string(msg), nil return string(msg), nil
} }
func (s *realScanner) ReadStatus() (StatusCode, error) { func (s *realScanner) ReadStatus(req string) (string, error) {
status := make([]byte, 4) return readStatusFailureAsError(s.reader, req, readHexLength)
n, err := io.ReadFull(s.reader, status)
if err != nil && err != io.ErrUnexpectedEOF {
return "", util.WrapErrorf(err, util.NetworkError, "error reading status")
} else if err == io.ErrUnexpectedEOF {
return StatusCode(status), errIncompleteMessage("status", n, 4)
}
return StatusCode(status), nil
} }
func (s *realScanner) ReadMessage() ([]byte, error) { func (s *realScanner) ReadMessage() ([]byte, error) {
var err error return readMessage(s.reader, readHexLength)
length, err := s.readLength()
if err != nil {
return nil, err
}
data := make([]byte, length)
n, err := io.ReadFull(s.reader, data)
if err != nil && err != io.ErrUnexpectedEOF {
return data, util.WrapErrorf(err, util.NetworkError, "error reading message data")
} else if err == io.ErrUnexpectedEOF {
return data, errIncompleteMessage("message data", n, length)
}
return data, nil
} }
func (s *realScanner) ReadUntilEof() ([]byte, error) { func (s *realScanner) ReadUntilEof() ([]byte, error) {
@ -102,9 +85,75 @@ func (s *realScanner) Close() error {
return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner") return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner")
} }
func (s *realScanner) readLength() (int, error) { var _ Scanner = &realScanner{}
// lengthReader is a func that readMessage uses to read message length.
// See readHexLength and readInt32.
type lengthReader func(io.Reader) (int, error)
// Reads the status, and if failure, reads the message and returns it as an error.
// If the status is success, doesn't read the message.
// req is just used to populate the AdbError, and can be nil.
// messageLengthReader is the function passed to readMessage if the status is failure.
func readStatusFailureAsError(r io.Reader, req string, messageLengthReader lengthReader) (string, error) {
status, err := readOctetString(req, r)
if err != nil {
return "", util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req)
}
if isFailureStatus(status) {
msg, err := readMessage(r, messageLengthReader)
if err != nil {
return "", util.WrapErrorf(err, util.NetworkError,
"server returned error for %s, but couldn't read the error message", req)
}
return "", adbServerError(req, string(msg))
}
return status, nil
}
func readOctetString(description string, r io.Reader) (string, error) {
octet := make([]byte, 4)
n, err := io.ReadFull(r, octet)
if err == io.ErrUnexpectedEOF {
return "", errIncompleteMessage(description, n, 4)
} else if err != nil {
return "", util.WrapErrorf(err, util.NetworkError, "error reading "+description)
}
return string(octet), nil
}
// readMessage reads a length from r, then reads length bytes and returns them.
// lengthReader is the function used to read the length. Most operations encode
// length as a hex string (readHexLength), but sync operations use little-endian
// binary encoding (readInt32).
func readMessage(r io.Reader, lengthReader lengthReader) ([]byte, error) {
var err error
length, err := lengthReader(r)
if err != nil {
return nil, err
}
data := make([]byte, length)
n, err := io.ReadFull(r, data)
if err != nil && err != io.ErrUnexpectedEOF {
return data, util.WrapErrorf(err, util.NetworkError, "error reading message data")
} else if err == io.ErrUnexpectedEOF {
return data, errIncompleteMessage("message data", n, length)
}
return data, nil
}
// readHexLength reads the next 4 bytes from r as an ASCII hex-encoded length and parses them into an int.
func readHexLength(r io.Reader) (int, error) {
lengthHex := make([]byte, 4) lengthHex := make([]byte, 4)
n, err := io.ReadFull(s.reader, lengthHex) n, err := io.ReadFull(r, lengthHex)
if err != nil { if err != nil {
return 0, errIncompleteMessage("length", n, 4) return 0, errIncompleteMessage("length", n, 4)
} }
@ -122,4 +171,10 @@ func (s *realScanner) readLength() (int, error) {
return int(length), nil return int(length), nil
} }
var _ Scanner = &realScanner{} // readInt32 reads the next 4 bytes from r as a little-endian integer.
// Returns an int instead of an int32 to match the lengthReader type.
func readInt32(r io.Reader) (int, error) {
var value int32
err := binary.Read(r, binary.LittleEndian, &value)
return int(value), err
}

View file

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"io" "io"
"io/ioutil"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -11,109 +12,120 @@ import (
) )
func TestReadStatusOkay(t *testing.T) { func TestReadStatusOkay(t *testing.T) {
s := NewScannerString("OKAYd") s := newEofReader("OKAYd")
status, err := s.ReadStatus() status, err := readStatusFailureAsError(s, "", readHexLength)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, status.IsSuccess()) assert.False(t, isFailureStatus(status))
assertNotEof(t, s) assertNotEof(t, s)
} }
func TestReadIncompleteStatus(t *testing.T) { func TestReadIncompleteStatus(t *testing.T) {
s := NewScannerString("oka") s := newEofReader("oka")
_, err := s.ReadStatus() _, err := readStatusFailureAsError(s, "", readHexLength)
assert.Equal(t, errIncompleteMessage("status", 3, 4), err) assert.EqualError(t, err, "NetworkError: error reading status for ")
assert.Equal(t, errIncompleteMessage("", 3, 4), err.(*util.Err).Cause)
assertEof(t, s)
}
func TestReadFailureIncompleteStatus(t *testing.T) {
s := newEofReader("FAIL")
_, err := readStatusFailureAsError(s, "req", readHexLength)
assert.EqualError(t, err, "NetworkError: server returned error for req, but couldn't read the error message")
assert.Error(t, err.(*util.Err).Cause)
assertEof(t, s)
}
func TestReadFailureEmptyStatus(t *testing.T) {
s := newEofReader("FAIL0000")
_, err := readStatusFailureAsError(s, "", readHexLength)
assert.EqualError(t, err, "AdbError: server error: ({Request: ServerMsg:})")
assert.NoError(t, err.(*util.Err).Cause)
assertEof(t, s)
}
func TestReadFailureStatus(t *testing.T) {
s := newEofReader("FAIL0004fail")
_, err := readStatusFailureAsError(s, "", readHexLength)
assert.EqualError(t, err, "AdbError: server error: fail ({Request: ServerMsg:fail})")
assert.NoError(t, err.(*util.Err).Cause)
assertEof(t, s)
}
func TestReadMessage(t *testing.T) {
s := newEofReader("0005hello")
msg, err := readMessage(s, readHexLength)
assert.NoError(t, err)
assert.Len(t, msg, 5)
assert.Equal(t, "hello", string(msg))
assertEof(t, s)
}
func TestReadMessageWithExtraData(t *testing.T) {
s := newEofReader("0005hellothere")
msg, err := readMessage(s, readHexLength)
assert.NoError(t, err)
assert.Len(t, msg, 5)
assert.Equal(t, "hello", string(msg))
assertNotEof(t, s)
}
func TestReadLongerMessage(t *testing.T) {
s := newEofReader("001b192.168.56.101:5555 device\n")
msg, err := readMessage(s, readHexLength)
assert.NoError(t, err)
assert.Len(t, msg, 27)
assert.Equal(t, "192.168.56.101:5555 device\n", string(msg))
assertEof(t, s)
}
func TestReadEmptyMessage(t *testing.T) {
s := newEofReader("0000")
msg, err := readMessage(s, readHexLength)
assert.NoError(t, err)
assert.Equal(t, "", string(msg))
assertEof(t, s)
}
func TestReadIncompleteMessage(t *testing.T) {
s := newEofReader("0005hel")
msg, err := readMessage(s, readHexLength)
assert.Error(t, err)
assert.Equal(t, errIncompleteMessage("message data", 3, 5), err)
assert.Equal(t, "hel\000\000", string(msg))
assertEof(t, s) assertEof(t, s)
} }
func TestReadLength(t *testing.T) { func TestReadLength(t *testing.T) {
s := NewScannerString("000a") s := newEofReader("000a")
l, err := s.readLength() l, err := readHexLength(s)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 10, l) assert.Equal(t, 10, l)
assertEof(t, s) assertEof(t, s)
} }
func TestReadIncompleteLength(t *testing.T) { func TestReadLengthIncompleteLength(t *testing.T) {
s := NewScannerString("aaa") s := newEofReader("aaa")
_, err := s.readLength() _, err := readHexLength(s)
assert.Equal(t, errIncompleteMessage("length", 3, 4), err) assert.Equal(t, errIncompleteMessage("length", 3, 4), err)
assertEof(t, s) assertEof(t, s)
} }
func TestReadMessage(t *testing.T) { func assertEof(t *testing.T, r io.Reader) {
s := NewScannerString("0005hello") msg, err := readMessage(r, readHexLength)
msg, err := ReadMessageString(s)
assert.NoError(t, err)
assert.Len(t, msg, 5)
assert.Equal(t, "hello", msg)
assertEof(t, s)
}
func TestReadMessageWithExtraData(t *testing.T) {
s := NewScannerString("0005hellothere")
msg, err := ReadMessageString(s)
assert.NoError(t, err)
assert.Len(t, msg, 5)
assert.Equal(t, "hello", msg)
assertNotEof(t, s)
}
func TestReadLongerMessage(t *testing.T) {
s := NewScannerString("001b192.168.56.101:5555 device\n")
msg, err := ReadMessageString(s)
assert.NoError(t, err)
assert.Len(t, msg, 27)
assert.Equal(t, "192.168.56.101:5555 device\n", msg)
assertEof(t, s)
}
func TestReadEmptyMessage(t *testing.T) {
s := NewScannerString("0000")
msg, err := ReadMessageString(s)
assert.NoError(t, err)
assert.Equal(t, "", msg)
assertEof(t, s)
}
func TestReadIncompleteMessage(t *testing.T) {
s := NewScannerString("0005hel")
msg, err := ReadMessageString(s)
assert.Error(t, err)
assert.Equal(t, errIncompleteMessage("message data", 3, 5), err)
assert.Equal(t, "hel\000\000", msg)
assertEof(t, s)
}
func NewScannerString(str string) *realScanner {
return NewScanner(NewEofBuffer(str)).(*realScanner)
}
// NewEofBuffer returns a bytes.Buffer of str that returns an EOF error
// at the end of input, instead of just returning 0 bytes read.
func NewEofBuffer(str string) *TestReader {
limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str)))
bufReader := bufio.NewReader(limitReader)
return &TestReader{bufReader}
}
func assertEof(t *testing.T, s *realScanner) {
msg, err := s.ReadMessage()
assert.True(t, util.HasErrCode(err, util.ConnectionResetError)) assert.True(t, util.HasErrCode(err, util.ConnectionResetError))
assert.Nil(t, msg) assert.Nil(t, msg)
} }
func assertNotEof(t *testing.T, s *realScanner) { func assertNotEof(t *testing.T, r io.Reader) {
n, err := s.reader.Read(make([]byte, 1)) n, err := r.Read(make([]byte, 1))
assert.Equal(t, 1, n) assert.Equal(t, 1, n)
assert.NoError(t, err) assert.NoError(t, err)
} }
// TestReader is a wrapper around a bufio.Reader that implements io.Closer. // newEofBuffer returns a bytes.Buffer of str that returns an EOF error
type TestReader struct { // at the end of input, instead of just returning 0 bytes read.
*bufio.Reader func newEofReader(str string) io.ReadCloser {
} limitReader := io.LimitReader(bytes.NewBufferString(str), int64(len(str)))
bufReader := bufio.NewReader(limitReader)
func (b *TestReader) Close() error { return ioutil.NopCloser(bufReader)
// No-op.
return nil
} }

View file

@ -10,8 +10,8 @@ import (
) )
type SyncScanner interface { type SyncScanner interface {
// ReadOctetString reads a 4-byte string. io.Closer
ReadOctetString() (string, error) StatusReader
ReadInt32() (int32, error) ReadInt32() (int32, error)
ReadFileMode() (os.FileMode, error) ReadFileMode() (os.FileMode, error)
ReadTime() (time.Time, error) ReadTime() (time.Time, error)
@ -23,9 +23,6 @@ type SyncScanner interface {
// bytes (see io.LimitReader). The returned reader should be fully // bytes (see io.LimitReader). The returned reader should be fully
// read before reading anything off the Scanner again. // read before reading anything off the Scanner again.
ReadBytes() (io.Reader, error) ReadBytes() (io.Reader, error)
// Closes the underlying reader.
Close() error
} }
type realSyncScanner struct { type realSyncScanner struct {
@ -36,33 +33,13 @@ func NewSyncScanner(r io.Reader) SyncScanner {
return &realSyncScanner{r} return &realSyncScanner{r}
} }
func RequireOctetString(s SyncScanner, expected string) error { func (s *realSyncScanner) ReadStatus(req string) (string, error) {
actual, err := s.ReadOctetString() return readStatusFailureAsError(s.Reader, req, readInt32)
if err != nil {
return util.WrapErrorf(err, util.NetworkError, "expected to read '%s'", expected)
}
if actual != expected {
return util.AssertionErrorf("expected to read '%s', got '%s'", expected, actual)
}
return nil
} }
func (s *realSyncScanner) ReadOctetString() (string, error) {
octet := make([]byte, 4)
n, err := io.ReadFull(s.Reader, octet)
if err != nil && err != io.ErrUnexpectedEOF {
return "", util.WrapErrorf(err, util.NetworkError, "error reading octet string from sync scanner")
} else if err == io.ErrUnexpectedEOF {
return "", errIncompleteMessage("octet", n, 4)
}
return string(octet), nil
}
func (s *realSyncScanner) ReadInt32() (int32, error) { func (s *realSyncScanner) ReadInt32() (int32, error) {
var value int32 value, err := readInt32(s.Reader)
err := binary.Read(s.Reader, binary.LittleEndian, &value) return int32(value), util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner")
return value, util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner")
} }
func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) { func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) {
var value uint32 var value uint32

View file

@ -17,13 +17,6 @@ var (
someTimeEncoded = []byte{151, 208, 42, 85} someTimeEncoded = []byte{151, 208, 42, 85}
) )
func TestSyncReadOctetString(t *testing.T) {
s := NewSyncScanner(strings.NewReader("helo"))
str, err := s.ReadOctetString()
assert.NoError(t, err)
assert.Equal(t, "helo", str)
}
func TestSyncSendOctetString(t *testing.T) { func TestSyncSendOctetString(t *testing.T) {
var buf bytes.Buffer var buf bytes.Buffer
s := NewSyncSender(&buf) s := NewSyncSender(&buf)

View file

@ -21,28 +21,6 @@ type ErrorResponseDetails struct {
// Old servers send "device not found", and newer ones "device 'serial' not found". // Old servers send "device not found", and newer ones "device 'serial' not found".
var deviceNotFoundMessagePattern = regexp.MustCompile(`device( '.*')? not found`) var deviceNotFoundMessagePattern = regexp.MustCompile(`device( '.*')? not found`)
// Reads the status, and if failure, reads the message and returns it as an error.
// If the status is success, doesn't read the message.
// req is just used to populate the AdbError, and can be nil.
func ReadStatusFailureAsError(s Scanner, req string) error {
status, err := s.ReadStatus()
if err != nil {
return util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req)
}
if !status.IsSuccess() {
msg, err := s.ReadMessage()
if err != nil {
return util.WrapErrorf(err, util.NetworkError,
"server returned error for %s, but couldn't read the error message", req)
}
return adbServerError(req, string(msg))
}
return nil
}
func adbServerError(request string, serverMsg string) error { func adbServerError(request string, serverMsg string) error {
var msg string var msg string
if request == "" { if request == "" {
@ -66,6 +44,15 @@ func adbServerError(request string, serverMsg string) error {
} }
} }
// IsAdbServerErrorMatching returns true if err is an *util.Err with code AdbError and for which
// predicate returns true when passed Details.ServerMsg.
func IsAdbServerErrorMatching(err error, predicate func(string) bool) bool {
if err, ok := err.(*util.Err); ok && err.Code == util.AdbError {
return predicate(err.Details.(ErrorResponseDetails).ServerMsg)
}
return false
}
func errIncompleteMessage(description string, actual int, expected int) error { func errIncompleteMessage(description string, actual int, expected int) error {
return &util.Err{ return &util.Err{
Code: util.ConnectionResetError, Code: util.ConnectionResetError,