Merge pull request #10 from zach-klippenstein/server

Refactored ClientConfig into a Server interface.
This commit is contained in:
Zach Klippenstein 2016-01-10 14:24:22 -08:00
commit ee01d1202e
15 changed files with 420 additions and 278 deletions

View file

@ -1,16 +0,0 @@
package goadb
var (
defaultDialer Dialer = NewDialer("", 0)
)
type ClientConfig struct {
Dialer Dialer
}
func (c ClientConfig) sanitized() ClientConfig {
if c.Dialer == nil {
c.Dialer = defaultDialer
}
return c
}

View file

@ -36,9 +36,18 @@ var (
pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String() pushRemoteArg = pushCommand.Arg("remote", "Path of destination file on device.").Required().String()
) )
var server goadb.Server
func main() { func main() {
var exitCode int var exitCode int
var err error
server, err = goadb.NewServer(goadb.ServerConfig{})
if err != nil {
fmt.Fprintln(os.Stderr, "error:", err)
os.Exit(1)
}
switch kingpin.Parse() { switch kingpin.Parse() {
case "devices": case "devices":
exitCode = listDevices(*devicesLongFlag) exitCode = listDevices(*devicesLongFlag)
@ -62,7 +71,7 @@ func parseDevice() goadb.DeviceDescriptor {
} }
func listDevices(long bool) int { func listDevices(long bool) int {
client := goadb.NewHostClient(goadb.ClientConfig{}) client := goadb.NewHostClient(server)
devices, err := client.ListDevices() devices, err := client.ListDevices()
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "error:", err) fmt.Fprintln(os.Stderr, "error:", err)
@ -99,7 +108,7 @@ func runShellCommand(commandAndArgs []string, device goadb.DeviceDescriptor) int
args = commandAndArgs[1:] args = commandAndArgs[1:]
} }
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) client := goadb.NewDeviceClient(server, device)
output, err := client.RunCommand(command, args...) output, err := client.RunCommand(command, args...)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "error:", err) fmt.Fprintln(os.Stderr, "error:", err)
@ -121,7 +130,7 @@ func pull(showProgress bool, remotePath, localPath string, device goadb.DeviceDe
localPath = filepath.Base(remotePath) localPath = filepath.Base(remotePath)
} }
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) client := goadb.NewDeviceClient(server, device)
info, err := client.Stat(remotePath) info, err := client.Stat(remotePath)
if util.HasErrCode(err, util.FileNoExistError) { if util.HasErrCode(err, util.FileNoExistError) {
@ -194,7 +203,7 @@ func push(showProgress bool, localPath, remotePath string, device goadb.DeviceDe
} }
defer localFile.Close() defer localFile.Close()
client := goadb.NewDeviceClient(goadb.ClientConfig{}, device) client := goadb.NewDeviceClient(server, device)
writer, err := client.OpenWrite(remotePath, perms, mtime) writer, err := client.OpenWrite(remotePath, perms, mtime)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err) fmt.Fprintf(os.Stderr, "error opening remote file %s: %s\n", remotePath, err)

View file

@ -12,15 +12,26 @@ import (
"github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/util"
) )
var port = flag.Int("p", adb.AdbPort, "") var (
port = flag.Int("p", adb.AdbPort, "")
server adb.Server
)
func main() { func main() {
flag.Parse() flag.Parse()
client := adb.NewHostClient(adb.ClientConfig{}) var err error
server, err = adb.NewServer(adb.ServerConfig{
Port: *port,
})
if err != nil {
log.Fatal(err)
}
fmt.Println("Starting server…") fmt.Println("Starting server…")
adb.StartServer() server.Start()
client := adb.NewHostClient(server)
serverVersion, err := client.GetServerVersion() serverVersion, err := client.GetServerVersion()
if err != nil { if err != nil {
@ -51,7 +62,7 @@ func main() {
fmt.Println() fmt.Println()
fmt.Println("Watching for device state changes.") fmt.Println("Watching for device state changes.")
watcher := adb.NewDeviceWatcher(adb.ClientConfig{}) watcher := adb.NewDeviceWatcher(server)
for event := range watcher.C() { for event := range watcher.C() {
fmt.Printf("\t[%s]%+v\n", time.Now(), event) fmt.Printf("\t[%s]%+v\n", time.Now(), event)
} }
@ -77,7 +88,7 @@ func printErr(err error) {
} }
func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) { func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) {
device := adb.NewDeviceClient(adb.ClientConfig{}, descriptor) device := adb.NewDeviceClient(server, descriptor)
if err := PrintDeviceInfo(device); err != nil { if err := PrintDeviceInfo(device); err != nil {
log.Println(err) log.Println(err)
} }

View file

@ -49,7 +49,14 @@ func readLine() string {
} }
func doCommand(cmd string) error { func doCommand(cmd string) error {
conn, err := goadb.NewDialer("", *port).Dial() server, err := goadb.NewServer(goadb.ServerConfig{
Port: *port,
})
if err != nil {
log.Fatal(err)
}
conn, err := server.Dial()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View file

@ -17,18 +17,18 @@ var MtimeOfClose = time.Time{}
// DeviceClient communicates with a specific Android device. // DeviceClient communicates with a specific Android device.
type DeviceClient struct { type DeviceClient struct {
config ClientConfig server Server
descriptor DeviceDescriptor descriptor DeviceDescriptor
// Used to get device info. // Used to get device info.
deviceListFunc func() ([]*DeviceInfo, error) deviceListFunc func() ([]*DeviceInfo, error)
} }
func NewDeviceClient(config ClientConfig, descriptor DeviceDescriptor) *DeviceClient { func NewDeviceClient(server Server, descriptor DeviceDescriptor) *DeviceClient {
return &DeviceClient{ return &DeviceClient{
config: config.sanitized(), server: server,
descriptor: descriptor, descriptor: descriptor,
deviceListFunc: NewHostClient(config).ListDevices, deviceListFunc: NewHostClient(server).ListDevices,
} }
} }
@ -194,7 +194,7 @@ func (c *DeviceClient) OpenWrite(path string, perms os.FileMode, mtime time.Time
// getAttribute returns the first message returned by the server by running // getAttribute returns the first message returned by the server by running
// <host-prefix>:<attr>, where host-prefix is determined from the DeviceDescriptor. // <host-prefix>:<attr>, where host-prefix is determined from the DeviceDescriptor.
func (c *DeviceClient) getAttribute(attr string) (string, error) { func (c *DeviceClient) getAttribute(attr string) (string, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, resp, err := roundTripSingleResponse(c.server,
fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr)) fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr))
if err != nil { if err != nil {
return "", err return "", err
@ -222,7 +222,7 @@ func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
// dialDevice switches the connection to communicate directly with the device // dialDevice switches the connection to communicate directly with the device
// by requesting the transport defined by the DeviceDescriptor. // by requesting the transport defined by the DeviceDescriptor.
func (c *DeviceClient) dialDevice() (*wire.Conn, error) { func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
conn, err := c.config.Dialer.Dial() conn, err := c.server.Dial()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -13,12 +13,7 @@ func TestGetAttribute(t *testing.T) {
Status: wire.StatusSuccess, Status: wire.StatusSuccess,
Messages: []string{"value"}, Messages: []string{"value"},
} }
client := NewDeviceClient( client := NewDeviceClient(s, DeviceWithSerial("serial"))
ClientConfig{
Dialer: s,
},
DeviceWithSerial("serial"),
)
v, err := client.getAttribute("attr") v, err := client.getAttribute("attr")
assert.Equal(t, "host-serial:serial:attr", s.Requests[0]) assert.Equal(t, "host-serial:serial:attr", s.Requests[0])
@ -60,12 +55,10 @@ func TestGetDeviceInfo(t *testing.T) {
func newDeviceClientWithDeviceLister(serial string, deviceLister func() ([]*DeviceInfo, error)) *DeviceClient { func newDeviceClientWithDeviceLister(serial string, deviceLister func() ([]*DeviceInfo, error)) *DeviceClient {
client := NewDeviceClient( client := NewDeviceClient(
ClientConfig{ &MockServer{
Dialer: &MockServer{
Status: wire.StatusSuccess, Status: wire.StatusSuccess,
Messages: []string{serial}, Messages: []string{serial},
}, },
},
DeviceWithSerial(serial), DeviceWithSerial(serial),
) )
client.deviceListFunc = deviceLister client.deviceListFunc = deviceLister
@ -77,12 +70,7 @@ func TestRunCommandNoArgs(t *testing.T) {
Status: wire.StatusSuccess, Status: wire.StatusSuccess,
Messages: []string{"output"}, Messages: []string{"output"},
} }
client := NewDeviceClient( client := NewDeviceClient(s, AnyDevice())
ClientConfig{
Dialer: s,
},
AnyDevice(),
)
v, err := client.RunCommand("cmd") v, err := client.RunCommand("cmd")
assert.Equal(t, "host:transport-any", s.Requests[0]) assert.Equal(t, "host:transport-any", s.Requests[0])

View file

@ -2,10 +2,10 @@ package goadb
import ( import (
"log" "log"
"math/rand"
"runtime" "runtime"
"strings" "strings"
"sync/atomic" "sync/atomic"
"math/rand"
"time" "time"
"github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/util"
@ -59,22 +59,18 @@ var deviceStateStrings = map[string]DeviceState{
} }
type deviceWatcherImpl struct { type deviceWatcherImpl struct {
config ClientConfig server Server
// If an error occurs, it is stored here and eventChan is close immediately after. // If an error occurs, it is stored here and eventChan is close immediately after.
err atomic.Value err atomic.Value
eventChan chan DeviceStateChangedEvent eventChan chan DeviceStateChangedEvent
// Function to start the server if it's not running or dies.
startServer func() error
} }
func NewDeviceWatcher(config ClientConfig) *DeviceWatcher { func NewDeviceWatcher(server Server) *DeviceWatcher {
watcher := &DeviceWatcher{&deviceWatcherImpl{ watcher := &DeviceWatcher{&deviceWatcherImpl{
config: config.sanitized(), server: server,
eventChan: make(chan DeviceStateChangedEvent), eventChan: make(chan DeviceStateChangedEvent),
startServer: StartServer,
}} }}
runtime.SetFinalizer(watcher, func(watcher *DeviceWatcher) { runtime.SetFinalizer(watcher, func(watcher *DeviceWatcher) {
@ -134,7 +130,7 @@ func publishDevices(watcher *deviceWatcherImpl) {
finished := false finished := false
for { for {
scanner, err := connectToTrackDevices(watcher.config.Dialer) scanner, err := connectToTrackDevices(watcher.server)
if err != nil { if err != nil {
watcher.reportErr(err) watcher.reportErr(err)
return return
@ -156,7 +152,7 @@ func publishDevices(watcher *deviceWatcherImpl) {
log.Printf("[DeviceWatcher] server died, restarting in %s…", delay) log.Printf("[DeviceWatcher] server died, restarting in %s…", delay)
time.Sleep(delay) time.Sleep(delay)
if err := watcher.startServer(); err != nil { if err := watcher.server.Start(); err != nil {
log.Println("[DeviceWatcher] error restarting server, giving up") log.Println("[DeviceWatcher] error restarting server, giving up")
watcher.reportErr(err) watcher.reportErr(err)
return return
@ -169,8 +165,8 @@ func publishDevices(watcher *deviceWatcherImpl) {
} }
} }
func connectToTrackDevices(dialer Dialer) (wire.Scanner, error) { func connectToTrackDevices(server Server) (wire.Scanner, error) {
conn, err := dialer.Dial() conn, err := server.Dial()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,7 +1,6 @@
package goadb package goadb
import ( import (
"log"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -207,8 +206,7 @@ func TestWentOffline(t *testing.T) {
} }
func TestPublishDevicesRestartsServer(t *testing.T) { func TestPublishDevicesRestartsServer(t *testing.T) {
starter := &MockServerStarter{} server := &MockServer{
dialer := &MockServer{
Status: wire.StatusSuccess, Status: wire.StatusSuccess,
Errs: []error{ Errs: []error{
nil, nil, nil, // Successful dial. nil, nil, nil, // Successful dial.
@ -217,34 +215,17 @@ func TestPublishDevicesRestartsServer(t *testing.T) {
}, },
} }
watcher := deviceWatcherImpl{ watcher := deviceWatcherImpl{
config: ClientConfig{dialer}, server: server,
eventChan: make(chan DeviceStateChangedEvent), eventChan: make(chan DeviceStateChangedEvent),
startServer: starter.StartServer,
} }
publishDevices(&watcher) publishDevices(&watcher)
assert.Empty(t, dialer.Errs) assert.Empty(t, server.Errs)
assert.Equal(t, []string{"host:track-devices"}, dialer.Requests) assert.Equal(t, []string{"host:track-devices"}, server.Requests)
assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Dial"}, dialer.Trace) assert.Equal(t, []string{"Dial", "SendMessage", "ReadStatus", "ReadMessage", "Start", "Dial"}, server.Trace)
err := watcher.err.Load().(*util.Err) err := watcher.err.Load().(*util.Err)
assert.Equal(t, util.ServerNotAvailable, err.Code) assert.Equal(t, util.ServerNotAvailable, err.Code)
assert.Equal(t, 1, starter.startCount)
}
type MockServerStarter struct {
startCount int
err error
}
func (s *MockServerStarter) StartServer() error {
log.Printf("Starting mock server")
if s.err == nil {
s.startCount += 1
return nil
} else {
return s.err
}
} }
func assertContainsOnly(t *testing.T, expected, actual []DeviceStateChangedEvent) { func assertContainsOnly(t *testing.T, expected, actual []DeviceStateChangedEvent) {

View file

@ -1,7 +1,6 @@
package goadb package goadb
import ( import (
"fmt"
"io" "io"
"net" "net"
"runtime" "runtime"
@ -10,62 +9,20 @@ import (
"github.com/zach-klippenstein/goadb/wire" "github.com/zach-klippenstein/goadb/wire"
) )
const ( // Dialer knows how to create connections to an adb server.
// Default port the adb server listens on.
AdbPort = 5037
)
/*
Dialer knows how to create connections to an adb server.
*/
type Dialer interface { type Dialer interface {
Dial() (*wire.Conn, error) Dial(address string) (*wire.Conn, error)
} }
/* type tcpDialer struct{}
NewDialer creates a new Dialer.
If host is "" or port is 0, "localhost:5037" is used.
*/
func NewDialer(host string, port int) Dialer {
if host == "" {
host = "localhost"
}
if port == 0 {
port = AdbPort
}
return &netDialer{host, port}
}
type netDialer struct {
Host string
Port int
}
func (d *netDialer) String() string {
return fmt.Sprintf("netDialer(%s:%d)", d.Host, d.Port)
}
// Dial connects to the adb server on the host and port set on the netDialer. // Dial connects to the adb server on the host and port set on the netDialer.
// The zero-value will connect to the default, localhost:5037. // The zero-value will connect to the default, localhost:5037.
func (d *netDialer) Dial() (*wire.Conn, error) { func (tcpDialer) Dial(address string) (*wire.Conn, error) {
host := d.Host
port := d.Port
address := fmt.Sprintf("%s:%d", host, port)
netConn, err := net.Dial("tcp", address) netConn, err := net.Dial("tcp", address)
if err != nil {
// Attempt to start the server and try again.
if err = StartServer(); err != nil {
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error starting server")
}
address = fmt.Sprintf("%s:%d", host, port)
netConn, err = net.Dial("tcp", address)
if err != nil { if err != nil {
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address) return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error dialing %s", address)
} }
}
// net.Conn can't be closed more than once, but wire.Conn will try to close both sender and scanner // net.Conn can't be closed more than once, but wire.Conn will try to close both sender and scanner
// so we need to wrap it to make it safe. // so we need to wrap it to make it safe.
@ -84,13 +41,3 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
Sender: wire.NewSender(safeConn), Sender: wire.NewSender(safeConn),
}, nil }, nil
} }
func roundTripSingleResponse(d Dialer, req string) ([]byte, error) {
conn, err := d.Dial()
if err != nil {
return nil, err
}
defer conn.Close()
return conn.RoundTripSingleResponse([]byte(req))
}

View file

@ -19,16 +19,16 @@ See list of services at https://android.googlesource.com/platform/system/core/+/
*/ */
// TODO(z): Finish implementing host services. // TODO(z): Finish implementing host services.
type HostClient struct { type HostClient struct {
config ClientConfig server Server
} }
func NewHostClient(config ClientConfig) *HostClient { func NewHostClient(server Server) *HostClient {
return &HostClient{config.sanitized()} return &HostClient{server}
} }
// GetServerVersion asks the ADB server for its internal version number. // GetServerVersion asks the ADB server for its internal version number.
func (c *HostClient) GetServerVersion() (int, error) { func (c *HostClient) GetServerVersion() (int, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, "host:version") resp, err := roundTripSingleResponse(c.server, "host:version")
if err != nil { if err != nil {
return 0, wrapClientError(err, c, "GetServerVersion") return 0, wrapClientError(err, c, "GetServerVersion")
} }
@ -47,7 +47,7 @@ Corresponds to the command:
adb kill-server adb kill-server
*/ */
func (c *HostClient) KillServer() error { func (c *HostClient) KillServer() error {
conn, err := c.config.Dialer.Dial() conn, err := c.server.Dial()
if err != nil { if err != nil {
return wrapClientError(err, c, "KillServer") return wrapClientError(err, c, "KillServer")
} }
@ -67,7 +67,7 @@ Corresponds to the command:
adb devices adb devices
*/ */
func (c *HostClient) ListDeviceSerials() ([]string, error) { func (c *HostClient) ListDeviceSerials() ([]string, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices") resp, err := roundTripSingleResponse(c.server, "host:devices")
if err != nil { if err != nil {
return nil, wrapClientError(err, c, "ListDeviceSerials") return nil, wrapClientError(err, c, "ListDeviceSerials")
} }
@ -91,7 +91,7 @@ Corresponds to the command:
adb devices -l adb devices -l
*/ */
func (c *HostClient) ListDevices() ([]*DeviceInfo, error) { func (c *HostClient) ListDevices() ([]*DeviceInfo, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l") resp, err := roundTripSingleResponse(c.server, "host:devices-l")
if err != nil { if err != nil {
return nil, wrapClientError(err, c, "ListDevices") return nil, wrapClientError(err, c, "ListDevices")
} }

View file

@ -1,12 +1,9 @@
package goadb package goadb
import ( import (
"io"
"strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire" "github.com/zach-klippenstein/goadb/wire"
) )
@ -15,113 +12,10 @@ func TestGetServerVersion(t *testing.T) {
Status: wire.StatusSuccess, Status: wire.StatusSuccess,
Messages: []string{"000a"}, Messages: []string{"000a"},
} }
client := NewHostClient(ClientConfig{ client := NewHostClient(s)
Dialer: s,
})
v, err := client.GetServerVersion() v, err := client.GetServerVersion()
assert.Equal(t, "host:version", s.Requests[0]) assert.Equal(t, "host:version", s.Requests[0])
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 10, v) assert.Equal(t, 10, v)
} }
// MockServer implements Dialer, Scanner, and Sender.
type MockServer struct {
// Each time an operation is performed, if this slice is non-empty, the head element
// of this slice is returned and removed from the slice. If the head is nil, it is removed
// but not returned.
Errs []error
Status string
// Messages are returned from read calls in order, each preceded by a length header.
Messages []string
nextMsgIndex int
// Each message passed to a send call is appended to this slice.
Requests []string
// Each time an operaiton is performed, its name is appended to this slice.
Trace []string
}
func (s *MockServer) Dial() (*wire.Conn, error) {
s.logMethod("Dial")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
return wire.NewConn(s, s), nil
}
func (s *MockServer) ReadStatus(req string) (string, error) {
s.logMethod("ReadStatus")
if err := s.getNextErrToReturn(); err != nil {
return "", err
}
return s.Status, nil
}
func (s *MockServer) ReadMessage() ([]byte, error) {
s.logMethod("ReadMessage")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
if s.nextMsgIndex >= len(s.Messages) {
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
}
s.nextMsgIndex++
return []byte(s.Messages[s.nextMsgIndex-1]), nil
}
func (s *MockServer) ReadUntilEof() ([]byte, error) {
s.logMethod("ReadUntilEof")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
var data []string
for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ {
data = append(data, s.Messages[s.nextMsgIndex])
}
return []byte(strings.Join(data, "")), nil
}
func (s *MockServer) SendMessage(msg []byte) error {
s.logMethod("SendMessage")
if err := s.getNextErrToReturn(); err != nil {
return err
}
s.Requests = append(s.Requests, string(msg))
return nil
}
func (s *MockServer) NewSyncScanner() wire.SyncScanner {
s.logMethod("NewSyncScanner")
return nil
}
func (s *MockServer) NewSyncSender() wire.SyncSender {
s.logMethod("NewSyncSender")
return nil
}
func (s *MockServer) Close() error {
s.logMethod("Close")
if err := s.getNextErrToReturn(); err != nil {
return err
}
return nil
}
func (s *MockServer) getNextErrToReturn() (err error) {
if len(s.Errs) > 0 {
err = s.Errs[0]
s.Errs = s.Errs[1:]
}
return
}
func (s *MockServer) logMethod(name string) {
s.Trace = append(s.Trace, name)
}

146
server.go Normal file
View file

@ -0,0 +1,146 @@
package goadb
import (
"errors"
"fmt"
"os"
"os/exec"
"strings"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
"golang.org/x/sys/unix"
)
const (
AdbExecutableName = "adb"
// Default port the adb server listens on.
AdbPort = 5037
)
type ServerConfig struct {
// Path to the adb executable. If empty, the PATH environment variable will be searched.
PathToAdb string
// Host and port the adb server is listening on.
// If not specified, will use the default port on localhost.
Host string
Port int
// Dialer used to connect to the adb server.
Dialer
}
// Server knows how to start the adb server and connect to it.
type Server interface {
Start() error
Dial() (*wire.Conn, error)
}
func roundTripSingleResponse(s Server, req string) ([]byte, error) {
conn, err := s.Dial()
if err != nil {
return nil, err
}
defer conn.Close()
return conn.RoundTripSingleResponse([]byte(req))
}
type realServer struct {
config ServerConfig
fs *filesystem
// Caches Host:Port so they don't have to be concatenated for every dial.
address string
}
// NewServer creates a new Server instance.
func NewServer(config ServerConfig) (Server, error) {
return newServer(config, localFilesystem)
}
func newServer(config ServerConfig, fs *filesystem) (Server, error) {
if config.Dialer == nil {
config.Dialer = tcpDialer{}
}
if config.Host == "" {
config.Host = "localhost"
}
if config.Port == 0 {
config.Port = AdbPort
}
if config.PathToAdb == "" {
path, err := fs.LookPath(AdbExecutableName)
if err != nil {
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "could not find %s in PATH", AdbExecutableName)
}
config.PathToAdb = path
}
if err := fs.IsExecutableFile(config.PathToAdb); err != nil {
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "invalid adb executable: %s", config.PathToAdb)
}
return &realServer{
config: config,
fs: fs,
address: fmt.Sprintf("%s:%d", config.Host, config.Port),
}, nil
}
// Dial tries to connect to the server. If the first attempt fails, tries starting the server before
// retrying. If the second attempt fails, returns the error.
func (s *realServer) Dial() (*wire.Conn, error) {
conn, err := s.config.Dial(s.address)
if err != nil {
// Attempt to start the server and try again.
if err = s.Start(); err != nil {
return nil, util.WrapErrorf(err, util.ServerNotAvailable, "error starting server for dial")
}
conn, err = s.config.Dial(s.address)
if err != nil {
return nil, err
}
}
return conn, nil
}
// StartServer ensures there is a server running.
func (s *realServer) Start() error {
output, err := s.fs.CmdCombinedOutput(s.config.PathToAdb, "start-server")
outputStr := strings.TrimSpace(string(output))
return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s\noutput:\n%s", err, outputStr)
}
// filesystem abstracts interactions with the local filesystem for testability.
type filesystem struct {
// Wraps exec.LookPath.
LookPath func(string) (string, error)
// Returns nil if path is a regular file and executable by the current user.
IsExecutableFile func(path string) error
// Wraps exec.Command().CombinedOutput()
CmdCombinedOutput func(name string, arg ...string) ([]byte, error)
}
var localFilesystem = &filesystem{
LookPath: exec.LookPath,
IsExecutableFile: func(path string) error {
info, err := os.Stat(path)
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return errors.New("not a regular file")
}
return unix.Access(path, unix.X_OK)
},
CmdCombinedOutput: func(name string, arg ...string) ([]byte, error) {
return exec.Command(name, arg...).CombinedOutput()
},
}

View file

@ -1,18 +0,0 @@
package goadb
import (
"os/exec"
"strings"
"github.com/zach-klippenstein/goadb/util"
)
/*
StartServer ensures there is a server running.
*/
func StartServer() error {
cmd := exec.Command("adb", "start-server")
output, err := cmd.CombinedOutput()
outputStr := strings.TrimSpace(string(output))
return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s\noutput:\n%s", err, outputStr)
}

117
server_mock_test.go Normal file
View file

@ -0,0 +1,117 @@
package goadb
import (
"io"
"strings"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
// MockServer implements Server, Scanner, and Sender.
type MockServer struct {
// Each time an operation is performed, if this slice is non-empty, the head element
// of this slice is returned and removed from the slice. If the head is nil, it is removed
// but not returned.
Errs []error
Status string
// Messages are returned from read calls in order, each preceded by a length header.
Messages []string
nextMsgIndex int
// Each message passed to a send call is appended to this slice.
Requests []string
// Each time an operation is performed, its name is appended to this slice.
Trace []string
}
var _ Server = &MockServer{}
func (s *MockServer) Dial() (*wire.Conn, error) {
s.logMethod("Dial")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
return wire.NewConn(s, s), nil
}
func (s *MockServer) Start() error {
s.logMethod("Start")
return nil
}
func (s *MockServer) ReadStatus(req string) (string, error) {
s.logMethod("ReadStatus")
if err := s.getNextErrToReturn(); err != nil {
return "", err
}
return s.Status, nil
}
func (s *MockServer) ReadMessage() ([]byte, error) {
s.logMethod("ReadMessage")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
if s.nextMsgIndex >= len(s.Messages) {
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
}
s.nextMsgIndex++
return []byte(s.Messages[s.nextMsgIndex-1]), nil
}
func (s *MockServer) ReadUntilEof() ([]byte, error) {
s.logMethod("ReadUntilEof")
if err := s.getNextErrToReturn(); err != nil {
return nil, err
}
var data []string
for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ {
data = append(data, s.Messages[s.nextMsgIndex])
}
return []byte(strings.Join(data, "")), nil
}
func (s *MockServer) SendMessage(msg []byte) error {
s.logMethod("SendMessage")
if err := s.getNextErrToReturn(); err != nil {
return err
}
s.Requests = append(s.Requests, string(msg))
return nil
}
func (s *MockServer) NewSyncScanner() wire.SyncScanner {
s.logMethod("NewSyncScanner")
return nil
}
func (s *MockServer) NewSyncSender() wire.SyncSender {
s.logMethod("NewSyncSender")
return nil
}
func (s *MockServer) Close() error {
s.logMethod("Close")
if err := s.getNextErrToReturn(); err != nil {
return err
}
return nil
}
func (s *MockServer) getNextErrToReturn() (err error) {
if len(s.Errs) > 0 {
err = s.Errs[0]
s.Errs = s.Errs[1:]
}
return
}
func (s *MockServer) logMethod(name string) {
s.Trace = append(s.Trace, name)
}

80
server_test.go Normal file
View file

@ -0,0 +1,80 @@
package goadb
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/wire"
)
func TestNewServer_ZeroConfig(t *testing.T) {
config := ServerConfig{}
fs := &filesystem{
LookPath: func(name string) (string, error) {
if name == AdbExecutableName {
return "/bin/adb", nil
}
return "", fmt.Errorf("invalid name: %s", name)
},
IsExecutableFile: func(path string) error {
if path == "/bin/adb" {
return nil
}
return fmt.Errorf("wrong path: %s", path)
},
}
serverIf, err := newServer(config, fs)
server := serverIf.(*realServer)
assert.NoError(t, err)
assert.IsType(t, tcpDialer{}, server.config.Dialer)
assert.Equal(t, "localhost", server.config.Host)
assert.Equal(t, AdbPort, server.config.Port)
assert.Equal(t, fmt.Sprintf("localhost:%d", AdbPort), server.address)
assert.Equal(t, "/bin/adb", server.config.PathToAdb)
}
type MockDialer struct{}
func (d MockDialer) Dial(address string) (*wire.Conn, error) {
return nil, nil
}
func TestNewServer_CustomConfig(t *testing.T) {
config := ServerConfig{
Dialer: MockDialer{},
Host: "foobar",
Port: 1,
PathToAdb: "/bin/adb",
}
fs := &filesystem{
IsExecutableFile: func(path string) error {
if path == "/bin/adb" {
return nil
}
return fmt.Errorf("wrong path: %s", path)
},
}
serverIf, err := newServer(config, fs)
server := serverIf.(*realServer)
assert.NoError(t, err)
assert.IsType(t, MockDialer{}, server.config.Dialer)
assert.Equal(t, "foobar", server.config.Host)
assert.Equal(t, 1, server.config.Port)
assert.Equal(t, fmt.Sprintf("foobar:1"), server.address)
assert.Equal(t, "/bin/adb", server.config.PathToAdb)
}
func TestNewServer_AdbNotFound(t *testing.T) {
config := ServerConfig{}
fs := &filesystem{
LookPath: func(name string) (string, error) {
return "", fmt.Errorf("executable not found: %s", name)
},
}
_, err := newServer(config, fs)
assert.EqualError(t, err, "ServerNotAvailable: could not find adb in PATH")
}