Refactored ClientConfig into a Server interface.
* StartServer is now a method on Server. * What used to be Dialer.Dial is now Server.Dial. * Server.Dial handles trying to start the server if the initial connection fails. * Dialer now dials a network address. * All types that took a Dialer now take a Server. * Server now has tests!
This commit is contained in:
parent
11dc26d9ba
commit
9f7d11a3bc
|
@ -1,16 +0,0 @@
|
|||
package goadb
|
||||
|
||||
var (
|
||||
defaultDialer Dialer = NewDialer("", 0)
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
Dialer Dialer
|
||||
}
|
||||
|
||||
func (c ClientConfig) sanitized() ClientConfig {
|
||||
if c.Dialer == nil {
|
||||
c.Dialer = defaultDialer
|
||||
}
|
||||
return c
|
||||
}
|
|
@ -36,9 +36,18 @@ var (
|
|||
pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String()
|
||||
)
|
||||
|
||||
var server goadb.Server
|
||||
|
||||
func main() {
|
||||
var exitCode int
|
||||
|
||||
var err error
|
||||
server, err = goadb.NewServer(goadb.ServerConfig{})
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch kingpin.Parse() {
|
||||
case "devices":
|
||||
exitCode = listDevices(*devicesLongFlag)
|
||||
|
@ -62,7 +71,7 @@ func parseDevice() goadb.DeviceDescriptor {
|
|||
}
|
||||
|
||||
func listDevices(long bool) int {
|
||||
client := goadb.NewHostClient(goadb.ClientConfig{})
|
||||
client := goadb.NewHostClient(server)
|
||||
devices, err := client.ListDevices()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error:", err)
|
||||
|
@ -99,7 +108,7 @@ func runShellCommand(commandAndArgs []string, device goadb.DeviceDescriptor) int
|
|||
args = commandAndArgs[1:]
|
||||
}
|
||||
|
||||
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
|
||||
client := goadb.NewDeviceClient(server, device)
|
||||
output, err := client.RunCommand(command, args...)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "error:", err)
|
||||
|
@ -121,7 +130,7 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
|
|||
localPath = filepath.Base(remotePath)
|
||||
}
|
||||
|
||||
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
|
||||
client := goadb.NewDeviceClient(server, device)
|
||||
|
||||
info, err := client.Stat(remotePath)
|
||||
if util.HasErrCode(err, util.FileNoExistError) {
|
||||
|
@ -194,7 +203,7 @@ func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDe
|
|||
}
|
||||
defer localFile.Close()
|
||||
|
||||
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device)
|
||||
client := goadb.NewDeviceClient(server, device)
|
||||
writer, err := client.OpenWrite(remotePath, perms, mtime)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err)
|
||||
|
|
|
@ -12,15 +12,26 @@ import (
|
|||
"github.com/zach-klippenstein/goadb/util"
|
||||
)
|
||||
|
||||
var port = flag.Int("p", adb.AdbPort, "")
|
||||
var (
|
||||
port = flag.Int("p", adb.AdbPort, "")
|
||||
|
||||
server adb.Server
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
client := adb.NewHostClient(adb.ClientConfig{})
|
||||
|
||||
var err error
|
||||
server, err = adb.NewServer(adb.ServerConfig{
|
||||
Port: *port,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("Starting server…")
|
||||
adb.StartServer()
|
||||
server.Start()
|
||||
|
||||
client := adb.NewHostClient(server)
|
||||
|
||||
serverVersion, err := client.GetServerVersion()
|
||||
if err != nil {
|
||||
|
@ -51,7 +62,7 @@ func main() {
|
|||
|
||||
fmt.Println()
|
||||
fmt.Println("Watching for device state changes.")
|
||||
watcher := adb.NewDeviceWatcher(adb.ClientConfig{})
|
||||
watcher := adb.NewDeviceWatcher(server)
|
||||
for event := range watcher.C() {
|
||||
fmt.Printf("\t[%s]%+v\n", time.Now(), event)
|
||||
}
|
||||
|
@ -77,7 +88,7 @@ func printErr(err error) {
|
|||
}
|
||||
|
||||
func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) {
|
||||
device := adb.NewDeviceClient(adb.ClientConfig{}, descriptor)
|
||||
device := adb.NewDeviceClient(server, descriptor)
|
||||
if err := PrintDeviceInfo(device); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
|
|
@ -49,7 +49,14 @@ func readLine() string {
|
|||
}
|
||||
|
||||
func doCommand(cmd string) error {
|
||||
conn, err := goadb.NewDialer("", *port).Dial()
|
||||
server, err := goadb.NewServer(goadb.ServerConfig{
|
||||
Port: *port,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
conn, err := server.Dial()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -17,18 +17,18 @@ var MtimeOfClose = time.Time{}
|
|||
|
||||
// DeviceClient communicates with a specific Android device.
|
||||
type DeviceClient struct {
|
||||
config ClientConfig
|
||||
server Server
|
||||
descriptor DeviceDescriptor
|
||||
|
||||
// Used to get device info.
|
||||
deviceListFunc func() ([]*DeviceInfo, error)
|
||||
}
|
||||
|
||||
func NewDeviceClient(config ClientConfig, descriptor DeviceDescriptor) *DeviceClient {
|
||||
func NewDeviceClient(server Server, descriptor DeviceDescriptor) *DeviceClient {
|
||||
return &DeviceClient{
|
||||
config: config.sanitized(),
|
||||
server: server,
|
||||
descriptor: descriptor,
|
||||
deviceListFunc: NewHostClient(config).ListDevices,
|
||||
deviceListFunc: NewHostClient(server).ListDevices,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -194,7 +194,7 @@ func (c *DeviceClient) OpenWrite(path string, perms os.FileMode, mtime time.Time
|
|||
// getAttribute returns the first message returned by the server by running
|
||||
// <host-prefix>:<attr>, where host-prefix is determined from the DeviceDescriptor.
|
||||
func (c *DeviceClient) getAttribute(attr string) (string, error) {
|
||||
resp, err := roundTripSingleResponse(c.config.Dialer,
|
||||
resp, err := roundTripSingleResponse(c.server,
|
||||
fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr))
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -222,7 +222,7 @@ func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
|
|||
// dialDevice switches the connection to communicate directly with the device
|
||||
// by requesting the transport defined by the DeviceDescriptor.
|
||||
func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
|
||||
conn, err := c.config.Dialer.Dial()
|
||||
conn, err := c.server.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -13,12 +13,7 @@ func TestGetAttribute(t *testing.T) {
|
|||
Status: wire.StatusSuccess,
|
||||
Messages: []string{"value"},
|
||||
}
|
||||
client := NewDeviceClient(
|
||||
ClientConfig{
|
||||
Dialer: s,
|
||||
},
|
||||
DeviceWithSerial("serial"),
|
||||
)
|
||||
client := NewDeviceClient(s, DeviceWithSerial("serial"))
|
||||
|
||||
v, err := client.getAttribute("attr")
|
||||
assert.Equal(t, "host-serial:serial:attr", s.Requests[0])
|
||||
|
@ -60,12 +55,10 @@ func TestGetDeviceInfo(t *testing.T) {
|
|||
|
||||
func newDeviceClientWithDeviceLister(serial string, deviceLister func() ([]*DeviceInfo, error)) *DeviceClient {
|
||||
client := NewDeviceClient(
|
||||
ClientConfig{
|
||||
Dialer: &MockServer{
|
||||
&MockServer{
|
||||
Status: wire.StatusSuccess,
|
||||
Messages: []string{serial},
|
||||
},
|
||||
},
|
||||
DeviceWithSerial(serial),
|
||||
)
|
||||
client.deviceListFunc = deviceLister
|
||||
|
@ -77,12 +70,7 @@ func TestRunCommandNoArgs(t *testing.T) {
|
|||
Status: wire.StatusSuccess,
|
||||
Messages: []string{"output"},
|
||||
}
|
||||
client := NewDeviceClient(
|
||||
ClientConfig{
|
||||
Dialer: s,
|
||||
},
|
||||
AnyDevice(),
|
||||
)
|
||||
client := NewDeviceClient(s, AnyDevice())
|
||||
|
||||
v, err := client.RunCommand("cmd")
|
||||
assert.Equal(t, "host:transport-any", s.Requests[0])
|
||||
|
|
|
@ -2,10 +2,10 @@ package goadb
|
|||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
|
@ -59,22 +59,18 @@ var deviceStateStrings = map[string]DeviceState{
|
|||
}
|
||||
|
||||
type deviceWatcherImpl struct {
|
||||
config ClientConfig
|
||||
server Server
|
||||
|
||||
// If an error occurs, it is stored here and eventChan is close immediately after.
|
||||
err atomic.Value
|
||||
|
||||
eventChan chan DeviceStateChangedEvent
|
||||
|
||||
// Function to start the server if it's not running or dies.
|
||||
startServer func() error
|
||||
}
|
||||
|
||||
func NewDeviceWatcher(config ClientConfig) *DeviceWatcher {
|
||||
func NewDeviceWatcher(server Server) *DeviceWatcher {
|
||||
watcher := &DeviceWatcher{&deviceWatcherImpl{
|
||||
config: config.sanitized(),
|
||||
server: server,
|
||||
eventChan: make(chan DeviceStateChangedEvent),
|
||||
startServer: StartServer,
|
||||
}}
|
||||
|
||||
runtime.SetFinalizer(watcher, func(watcher *DeviceWatcher) {
|
||||
|
@ -134,7 +130,7 @@ func publishDevices(watcher *deviceWatcherImpl) {
|
|||
finished := false
|
||||
|
||||
for {
|
||||
scanner, err := connectToTrackDevices(watcher.config.Dialer)
|
||||
scanner, err := connectToTrackDevices(watcher.server)
|
||||
if err != nil {
|
||||
watcher.reportErr(err)
|
||||
return
|
||||
|
@ -156,7 +152,7 @@ func publishDevices(watcher *deviceWatcherImpl) {
|
|||
|
||||
log.Printf("[DeviceWatcher] server died, restarting in %s…", delay)
|
||||
time.Sleep(delay)
|
||||
if err := watcher.startServer(); err != nil {
|
||||
if err := watcher.server.Start(); err != nil {
|
||||
log.Println("[DeviceWatcher] error restarting server, giving up")
|
||||
watcher.reportErr(err)
|
||||
return
|
||||
|
@ -169,8 +165,8 @@ func publishDevices(watcher *deviceWatcherImpl) {
|
|||
}
|
||||
}
|
||||
|
||||
func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) {
|
||||
conn, err := dialer.Dial()
|
||||
func connectToTrackDevices(server Server) (wire.Scanner, error) {
|
||||
conn, err := server.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -207,8 +206,7 @@ func TestWentOffline(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPublishDevicesRestartsServer(t *testing.T) {
|
||||
starter := &MockServerStarter{}
|
||||
dialer := &MockServer{
|
||||
server := &MockServer{
|
||||
Status: wire.StatusSuccess,
|
||||
Errs: []error{
|
||||
nil, nil, nil, // Successful dial.
|
||||
|
@ -217,34 +215,17 @@ func TestPublishDevicesRestartsServer(t *testing.T) {
|
|||
},
|
||||
}
|
||||
watcher := deviceWatcherImpl{
|
||||
config: ClientConfig{dialer},
|
||||
server: server,
|
||||
eventChan: make(chan DeviceStateChangedEvent),
|
||||
startServer: starter.StartServer,
|
||||
}
|
||||
|
||||
publishDevices(&watcher)
|
||||
|
||||
assert.Empty(t, dialer.Errs)
|
||||
assert.Equal(t, []string{"host:track-devices"}, dialer.Requests)
|
||||
assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Dial"}, dialer.Trace)
|
||||
assert.Empty(t, server.Errs)
|
||||
assert.Equal(t, []string{"host:track-devices"}, server.Requests)
|
||||
assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Start", "Dial"}, server.Trace)
|
||||
err := watcher.err.Load().(*util.Err)
|
||||
assert.Equal(t, util.ServerNotAvailable, err.Code)
|
||||
assert.Equal(t, 1, starter.startCount)
|
||||
}
|
||||
|
||||
type MockServerStarter struct {
|
||||
startCount int
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *MockServerStarter) StartServer() error {
|
||||
log.Printf("Starting mock server")
|
||||
if s.err == nil {
|
||||
s.startCount += 1
|
||||
return nil
|
||||
} else {
|
||||
return s.err
|
||||
}
|
||||
}
|
||||
|
||||
func assertContainsOnly(t *testing.T, expected, actual []DeviceStateChangedEvent) {
|
||||
|
|
61
dialer.go
61
dialer.go
|
@ -1,7 +1,6 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
|
@ -10,62 +9,20 @@ import (
|
|||
"github.com/zach-klippenstein/goadb/wire"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default port the adb server listens on.
|
||||
AdbPort = 5037
|
||||
)
|
||||
|
||||
/*
|
||||
Dialer knows how to create connections to an adb server.
|
||||
*/
|
||||
// Dialer knows how to create connections to an adb server.
|
||||
type Dialer interface {
|
||||
Dial() (*wire.Conn, error)
|
||||
Dial(address string) (*wire.Conn, error)
|
||||
}
|
||||
|
||||
/*
|
||||
NewDialer creates a new Dialer.
|
||||
|
||||
If host is "" or port is 0, "localhost:5037" is used.
|
||||
*/
|
||||
func NewDialer(host string, port int) Dialer {
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
if port == 0 {
|
||||
port = AdbPort
|
||||
}
|
||||
return &netDialer{host, port}
|
||||
}
|
||||
|
||||
type netDialer struct {
|
||||
Host string
|
||||
Port int
|
||||
}
|
||||
|
||||
func (d *netDialer) String() string {
|
||||
return fmt.Sprintf("netDialer(%s:%d)", d.Host, d.Port)
|
||||
}
|
||||
type tcpDialer struct{}
|
||||
|
||||
// Dial connects to the adb server on the host and port set on the netDialer.
|
||||
// The zero-value will connect to the default, localhost:5037.
|
||||
func (d *netDialer) Dial() (*wire.Conn, error) {
|
||||
host := d.Host
|
||||
port := d.Port
|
||||
|
||||
address := fmt.Sprintf("%s:%d", host, port)
|
||||
func (tcpDialer) Dial(address string) (*wire.Conn, error) {
|
||||
netConn, err := net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
// Attempt to start the server and try again.
|
||||
if err = StartServer(); err != nil {
|
||||
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error starting server")
|
||||
}
|
||||
|
||||
address = fmt.Sprintf("%s:%d", host, port)
|
||||
netConn, err = net.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address)
|
||||
}
|
||||
}
|
||||
|
||||
// net.Conn can't be closed more than once, but wire.Conn will try to close both sender and scanner
|
||||
// so we need to wrap it to make it safe.
|
||||
|
@ -84,13 +41,3 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
|
|||
Sender: wire.NewSender(safeConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func roundTripSingleResponse(d Dialer, req string) ([]byte, error) {
|
||||
conn, err := d.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.RoundTripSingleResponse([]byte(req))
|
||||
}
|
||||
|
|
|
@ -19,16 +19,16 @@ See list of services at https://android.googlesource.com/platform/system/core/+/
|
|||
*/
|
||||
// TODO(z): Finish implementing host services.
|
||||
type HostClient struct {
|
||||
config ClientConfig
|
||||
server Server
|
||||
}
|
||||
|
||||
func NewHostClient(config ClientConfig) *HostClient {
|
||||
return &HostClient{config.sanitized()}
|
||||
func NewHostClient(server Server) *HostClient {
|
||||
return &HostClient{server}
|
||||
}
|
||||
|
||||
// GetServerVersion asks the ADB server for its internal version number.
|
||||
func (c *HostClient) GetServerVersion() (int, error) {
|
||||
resp, err := roundTripSingleResponse(c.config.Dialer, "host:version")
|
||||
resp, err := roundTripSingleResponse(c.server, "host:version")
|
||||
if err != nil {
|
||||
return 0, wrapClientError(err, c, "GetServerVersion")
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ Corresponds to the command:
|
|||
adb kill-server
|
||||
*/
|
||||
func (c *HostClient) KillServer() error {
|
||||
conn, err := c.config.Dialer.Dial()
|
||||
conn, err := c.server.Dial()
|
||||
if err != nil {
|
||||
return wrapClientError(err, c, "KillServer")
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ Corresponds to the command:
|
|||
adb devices
|
||||
*/
|
||||
func (c *HostClient) ListDeviceSerials() ([]string, error) {
|
||||
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices")
|
||||
resp, err := roundTripSingleResponse(c.server, "host:devices")
|
||||
if err != nil {
|
||||
return nil, wrapClientError(err, c, "ListDeviceSerials")
|
||||
}
|
||||
|
@ -91,7 +91,7 @@ Corresponds to the command:
|
|||
adb devices -l
|
||||
*/
|
||||
func (c *HostClient) ListDevices() ([]*DeviceInfo, error) {
|
||||
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l")
|
||||
resp, err := roundTripSingleResponse(c.server, "host:devices-l")
|
||||
if err != nil {
|
||||
return nil, wrapClientError(err, c, "ListDevices")
|
||||
}
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
"github.com/zach-klippenstein/goadb/wire"
|
||||
)
|
||||
|
||||
|
@ -15,113 +12,10 @@ func TestGetServerVersion(t *testing.T) {
|
|||
Status: wire.StatusSuccess,
|
||||
Messages: []string{"000a"},
|
||||
}
|
||||
client := NewHostClient(ClientConfig{
|
||||
Dialer: s,
|
||||
})
|
||||
client := NewHostClient(s)
|
||||
|
||||
v, err := client.GetServerVersion()
|
||||
assert.Equal(t, "host:version", s.Requests[0])
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 10, v)
|
||||
}
|
||||
|
||||
// MockServer implements Dialer, Scanner, and Sender.
|
||||
type MockServer struct {
|
||||
// Each time an operation is performed, if this slice is non-empty, the head element
|
||||
// of this slice is returned and removed from the slice. If the head is nil, it is removed
|
||||
// but not returned.
|
||||
Errs []error
|
||||
|
||||
Status string
|
||||
|
||||
// Messages are returned from read calls in order, each preceded by a length header.
|
||||
Messages []string
|
||||
nextMsgIndex int
|
||||
|
||||
// Each message passed to a send call is appended to this slice.
|
||||
Requests []string
|
||||
|
||||
// Each time an operaiton is performed, its name is appended to this slice.
|
||||
Trace []string
|
||||
}
|
||||
|
||||
func (s *MockServer) Dial() (*wire.Conn, error) {
|
||||
s.logMethod("Dial")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wire.NewConn(s, s), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadStatus(req string) (string, error) {
|
||||
s.logMethod("ReadStatus")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Status, nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadMessage() ([]byte, error) {
|
||||
s.logMethod("ReadMessage")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.nextMsgIndex >= len(s.Messages) {
|
||||
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
|
||||
}
|
||||
|
||||
s.nextMsgIndex++
|
||||
return []byte(s.Messages[s.nextMsgIndex-1]), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadUntilEof() ([]byte, error) {
|
||||
s.logMethod("ReadUntilEof")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data []string
|
||||
for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ {
|
||||
data = append(data, s.Messages[s.nextMsgIndex])
|
||||
}
|
||||
return []byte(strings.Join(data, "")), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) SendMessage(msg []byte) error {
|
||||
s.logMethod("SendMessage")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Requests = append(s.Requests, string(msg))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) NewSyncScanner() wire.SyncScanner {
|
||||
s.logMethod("NewSyncScanner")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) NewSyncSender() wire.SyncSender {
|
||||
s.logMethod("NewSyncSender")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) Close() error {
|
||||
s.logMethod("Close")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) getNextErrToReturn() (err error) {
|
||||
if len(s.Errs) > 0 {
|
||||
err = s.Errs[0]
|
||||
s.Errs = s.Errs[1:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *MockServer) logMethod(name string) {
|
||||
s.Trace = append(s.Trace, name)
|
||||
}
|
||||
|
|
146
server.go
Normal file
146
server.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
"github.com/zach-klippenstein/goadb/wire"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
AdbExecutableName = "adb"
|
||||
|
||||
// Default port the adb server listens on.
|
||||
AdbPort = 5037
|
||||
)
|
||||
|
||||
type ServerConfig struct {
|
||||
// Path to the adb executable. If empty, the PATH environment variable will be searched.
|
||||
PathToAdb string
|
||||
|
||||
// Host and port the adb server is listening on.
|
||||
// If not specified, will use the default port on localhost.
|
||||
Host string
|
||||
Port int
|
||||
|
||||
// Dialer used to connect to the adb server.
|
||||
Dialer
|
||||
}
|
||||
|
||||
// Server knows how to start the adb server and connect to it.
|
||||
type Server interface {
|
||||
Start() error
|
||||
Dial() (*wire.Conn, error)
|
||||
}
|
||||
|
||||
func roundTripSingleResponse(s Server, req string) ([]byte, error) {
|
||||
conn, err := s.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return conn.RoundTripSingleResponse([]byte(req))
|
||||
}
|
||||
|
||||
type realServer struct {
|
||||
config ServerConfig
|
||||
fs *filesystem
|
||||
|
||||
// Caches Host:Port so they don't have to be concatenated for every dial.
|
||||
address string
|
||||
}
|
||||
|
||||
// NewServer creates a new Server instance.
|
||||
func NewServer(config ServerConfig) (Server, error) {
|
||||
return newServer(config, localFilesystem)
|
||||
}
|
||||
|
||||
func newServer(config ServerConfig, fs *filesystem) (Server, error) {
|
||||
if config.Dialer == nil {
|
||||
config.Dialer = tcpDialer{}
|
||||
}
|
||||
|
||||
if config.Host == "" {
|
||||
config.Host = "localhost"
|
||||
}
|
||||
if config.Port == 0 {
|
||||
config.Port = AdbPort
|
||||
}
|
||||
|
||||
if config.PathToAdb == "" {
|
||||
path, err := fs.LookPath(AdbExecutableName)
|
||||
if err != nil {
|
||||
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "could not find %s in PATH", AdbExecutableName)
|
||||
}
|
||||
config.PathToAdb = path
|
||||
}
|
||||
if err := fs.IsExecutableFile(config.PathToAdb); err != nil {
|
||||
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "invalid adb executable: %s", config.PathToAdb)
|
||||
}
|
||||
|
||||
return &realServer{
|
||||
config: config,
|
||||
fs: fs,
|
||||
address: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Dial tries to connect to the server. If the first attempt fails, tries starting the server before
|
||||
// retrying. If the second attempt fails, returns the error.
|
||||
func (s *realServer) Dial() (*wire.Conn, error) {
|
||||
conn, err := s.config.Dial(s.address)
|
||||
if err != nil {
|
||||
// Attempt to start the server and try again.
|
||||
if err = s.Start(); err != nil {
|
||||
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error starting server for dial")
|
||||
}
|
||||
|
||||
conn, err = s.config.Dial(s.address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// StartServer ensures there is a server running.
|
||||
func (s *realServer) Start() error {
|
||||
output, err := s.fs.CmdCombinedOutput(s.config.PathToAdb, "start-server")
|
||||
outputStr := strings.TrimSpace(string(output))
|
||||
return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s\noutput:\n%s", err, outputStr)
|
||||
}
|
||||
|
||||
// filesystem abstracts interactions with the local filesystem for testability.
|
||||
type filesystem struct {
|
||||
// Wraps exec.LookPath.
|
||||
LookPath func(string) (string, error)
|
||||
|
||||
// Returns nil if path is a regular file and executable by the current user.
|
||||
IsExecutableFile func(path string) error
|
||||
|
||||
// Wraps exec.Command().CombinedOutput()
|
||||
CmdCombinedOutput func(name string, arg ...string) ([]byte, error)
|
||||
}
|
||||
|
||||
var localFilesystem = &filesystem{
|
||||
LookPath: exec.LookPath,
|
||||
IsExecutableFile: func(path string) error {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
return errors.New("not a regular file")
|
||||
}
|
||||
return unix.Access(path, unix.X_OK)
|
||||
},
|
||||
CmdCombinedOutput: func(name string, arg ...string) ([]byte, error) {
|
||||
return exec.Command(name, arg...).CombinedOutput()
|
||||
},
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
)
|
||||
|
||||
/*
|
||||
StartServer ensures there is a server running.
|
||||
*/
|
||||
func StartServer() error {
|
||||
cmd := exec.Command("adb", "start-server")
|
||||
output, err := cmd.CombinedOutput()
|
||||
outputStr := strings.TrimSpace(string(output))
|
||||
return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s\noutput:\n%s", err, outputStr)
|
||||
}
|
117
server_mock_test.go
Normal file
117
server_mock_test.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/zach-klippenstein/goadb/util"
|
||||
"github.com/zach-klippenstein/goadb/wire"
|
||||
)
|
||||
|
||||
// MockServer implements Server, Scanner, and Sender.
|
||||
type MockServer struct {
|
||||
// Each time an operation is performed, if this slice is non-empty, the head element
|
||||
// of this slice is returned and removed from the slice. If the head is nil, it is removed
|
||||
// but not returned.
|
||||
Errs []error
|
||||
|
||||
Status string
|
||||
|
||||
// Messages are returned from read calls in order, each preceded by a length header.
|
||||
Messages []string
|
||||
nextMsgIndex int
|
||||
|
||||
// Each message passed to a send call is appended to this slice.
|
||||
Requests []string
|
||||
|
||||
// Each time an operation is performed, its name is appended to this slice.
|
||||
Trace []string
|
||||
}
|
||||
|
||||
var _ Server = &MockServer{}
|
||||
|
||||
func (s *MockServer) Dial() (*wire.Conn, error) {
|
||||
s.logMethod("Dial")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wire.NewConn(s, s), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) Start() error {
|
||||
s.logMethod("Start")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadStatus(req string) (string, error) {
|
||||
s.logMethod("ReadStatus")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return s.Status, nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadMessage() ([]byte, error) {
|
||||
s.logMethod("ReadMessage")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.nextMsgIndex >= len(s.Messages) {
|
||||
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
|
||||
}
|
||||
|
||||
s.nextMsgIndex++
|
||||
return []byte(s.Messages[s.nextMsgIndex-1]), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) ReadUntilEof() ([]byte, error) {
|
||||
s.logMethod("ReadUntilEof")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data []string
|
||||
for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ {
|
||||
data = append(data, s.Messages[s.nextMsgIndex])
|
||||
}
|
||||
return []byte(strings.Join(data, "")), nil
|
||||
}
|
||||
|
||||
func (s *MockServer) SendMessage(msg []byte) error {
|
||||
s.logMethod("SendMessage")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return err
|
||||
}
|
||||
s.Requests = append(s.Requests, string(msg))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) NewSyncScanner() wire.SyncScanner {
|
||||
s.logMethod("NewSyncScanner")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) NewSyncSender() wire.SyncSender {
|
||||
s.logMethod("NewSyncSender")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) Close() error {
|
||||
s.logMethod("Close")
|
||||
if err := s.getNextErrToReturn(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockServer) getNextErrToReturn() (err error) {
|
||||
if len(s.Errs) > 0 {
|
||||
err = s.Errs[0]
|
||||
s.Errs = s.Errs[1:]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *MockServer) logMethod(name string) {
|
||||
s.Trace = append(s.Trace, name)
|
||||
}
|
80
server_test.go
Normal file
80
server_test.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package goadb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zach-klippenstein/goadb/wire"
|
||||
)
|
||||
|
||||
func TestNewServer_ZeroConfig(t *testing.T) {
|
||||
config := ServerConfig{}
|
||||
fs := &filesystem{
|
||||
LookPath: func(name string) (string, error) {
|
||||
if name == AdbExecutableName {
|
||||
return "/bin/adb", nil
|
||||
}
|
||||
return "", fmt.Errorf("invalid name: %s", name)
|
||||
},
|
||||
IsExecutableFile: func(path string) error {
|
||||
if path == "/bin/adb" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("wrong path: %s", path)
|
||||
},
|
||||
}
|
||||
|
||||
serverIf, err := newServer(config, fs)
|
||||
server := serverIf.(*realServer)
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, tcpDialer{}, server.config.Dialer)
|
||||
assert.Equal(t, "localhost", server.config.Host)
|
||||
assert.Equal(t, AdbPort, server.config.Port)
|
||||
assert.Equal(t, fmt.Sprintf("localhost:%d", AdbPort), server.address)
|
||||
assert.Equal(t, "/bin/adb", server.config.PathToAdb)
|
||||
}
|
||||
|
||||
type MockDialer struct{}
|
||||
|
||||
func (d MockDialer) Dial(address string) (*wire.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestNewServer_CustomConfig(t *testing.T) {
|
||||
config := ServerConfig{
|
||||
Dialer: MockDialer{},
|
||||
Host: "foobar",
|
||||
Port: 1,
|
||||
PathToAdb: "/bin/adb",
|
||||
}
|
||||
fs := &filesystem{
|
||||
IsExecutableFile: func(path string) error {
|
||||
if path == "/bin/adb" {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("wrong path: %s", path)
|
||||
},
|
||||
}
|
||||
|
||||
serverIf, err := newServer(config, fs)
|
||||
server := serverIf.(*realServer)
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, MockDialer{}, server.config.Dialer)
|
||||
assert.Equal(t, "foobar", server.config.Host)
|
||||
assert.Equal(t, 1, server.config.Port)
|
||||
assert.Equal(t, fmt.Sprintf("foobar:1"), server.address)
|
||||
assert.Equal(t, "/bin/adb", server.config.PathToAdb)
|
||||
}
|
||||
|
||||
func TestNewServer_AdbNotFound(t *testing.T) {
|
||||
config := ServerConfig{}
|
||||
fs := &filesystem{
|
||||
LookPath: func(name string) (string, error) {
|
||||
return "", fmt.Errorf("executable not found: %s", name)
|
||||
},
|
||||
}
|
||||
|
||||
_, err := newServer(config, fs)
|
||||
assert.EqualError(t, err, "ServerNotAvailable: could not find adb in PATH")
|
||||
}
|
Loading…
Reference in a new issue