package connection
import (
"encoding/binary"
"errors"
"io"
"net"
"sync"
"time"
)
// maxLength is the largest payload length that fits in the 4-byte big-endian
// frame header. It is an untyped constant so the declaration also compiles on
// platforms where int is 32 bits (a typed int constant of 1<<32-1 overflows
// there).
const maxLength = 1<<32 - 1

var (
	// rHeadBytes and wHeadBytes are package-level 4-byte header buffers.
	// NOTE(review): being global, they are shared by every connection and are
	// not safe for concurrent use; per-call buffers are preferable.
	rHeadBytes [4]byte
	wHeadBytes [4]byte

	// errMsgRead is not referenced in the visible code.
	errMsgRead = errors.New("Message read length error")
	// errHeadLen reports a frame header shorter than 4 bytes.
	errHeadLen = errors.New("Message head length error")
	// errMsgLen reports a payload length that is non-positive or exceeds maxLength.
	errMsgLen = errors.New("Message length is no longer in normal range")
)
// connPool is a sync.Pool of *connection values.
// NOTE(review): recycling *connection values is risky — a closed value may
// still be referenced by its former owner or have a mutex held by an
// undrained reader — so allocating a fresh value per connection is safer.
var connPool = sync.Pool{}
// Newconnection wraps conn in a Conn that frames every message with a 4-byte
// big-endian length header.
//
// A fresh value is allocated on every call. Recycling through connPool is not
// safe: a closed *connection may still be referenced by its previous owner
// (or have been returned to the pool more than once), and its read lock may
// still be held by an undrained reader, so reuse could alias two sockets or
// deadlock the new owner.
func Newconnection(conn net.Conn) Conn {
	return &connection{rwc: conn}
}
// Conn is the message-oriented connection returned by Newconnection: each
// Write or Writer call sends one length-prefixed frame, and each Read yields
// one frame's payload.
type Conn interface {
	// Read blocks until the next frame header arrives and returns a reader
	// over that frame's payload together with the payload size.
	Read() (r io.Reader, size int, err error)
	// Write sends p as a single frame.
	Write(p []byte) (n int, err error)
	// Writer sends exactly size bytes taken from r as a single frame.
	Writer(size int, r io.Reader) (n int64, err error)
	// Close closes the underlying connection.
	Close() (err error)

	// Address and deadline accessors of the underlying connection.
	RemoteAddr() net.Addr
	LocalAddr() net.Addr
	SetDeadline(t time.Time) error
	SetReadDeadline(t time.Time) error
	SetWriteDeadline(t time.Time) error
}
// connection implements Conn on top of a net.Conn using length-prefixed
// frames. Reads and writes are serialized by separate mutexes, so one reader
// and one writer may be active at the same time.
type connection struct {
	rwc net.Conn // underlying transport

	rlock sync.Mutex // serializes readers; held until a frame payload is drained
	rlen  int        // payload length from the most recently read header

	wlock sync.Mutex // serializes writers so a header and its payload stay adjacent
}
// rhead reads the 4-byte big-endian length prefix of the next frame and
// stores it in self.rlen. Read calls it while holding self.rlock.
//
// It returns the underlying read error (io.EOF if the stream ended cleanly
// before a header), or errHeadLen if the stream ended part-way through a
// header.
func (self *connection) rhead() error {
	// Per-call buffer: a package-level buffer would be shared by every
	// connection and race when several connections read at once.
	var head [4]byte
	// io.ReadFull retries short reads; a single Read on a stream socket may
	// legally return fewer than 4 bytes without an error.
	if _, err := io.ReadFull(self.rwc, head[:]); err != nil {
		if errors.Is(err, io.ErrUnexpectedEOF) {
			return errHeadLen
		}
		return err
	}
	self.rlen = int(binary.BigEndian.Uint32(head[:]))
	return nil
}
// whead validates the payload length l and writes it as the 4-byte
// big-endian frame header. Write and Writer call it while holding self.wlock.
//
// It returns errMsgLen if l is not in (0, maxLength], otherwise the error
// from writing the header.
func (self *connection) whead(l int) error {
	// Compare in int64 so the bound check is valid regardless of int width.
	if l <= 0 || int64(l) > int64(maxLength) {
		return errMsgLen
	}
	// Per-call buffer instead of a shared package-level one, which races
	// when several connections write concurrently.
	var head [4]byte
	binary.BigEndian.PutUint32(head[:], uint32(l))
	_, err := self.rwc.Write(head[:])
	return err
}
// Read waits for the next frame and returns a reader over its payload along
// with the payload size taken from the header.
//
// The connection's read lock is held from the header read until the returned
// reader reports an error (io.EOF once size bytes have been consumed), so the
// caller must drain r before another Read on this connection can proceed. If
// reading the header fails, the lock is released and that error is returned.
func (self *connection) Read() (r io.Reader, size int, err error) {
	self.rlock.Lock()
	if err = self.rhead(); err != nil {
		self.rlock.Unlock()
		return nil, 0, err
	}
	size = self.rlen
	// limitRead calls unlock on every error it sees; wrap it in sync.Once so
	// a caller that reads again after io.EOF does not trigger
	// "sync: unlock of unlocked mutex".
	var once sync.Once
	unlock := func() { once.Do(self.rlock.Unlock) }
	return limitRead{r: io.LimitReader(self.rwc, int64(size)), unlock: unlock}, size, nil
}
// Write sends p as one length-prefixed frame and returns the number of
// payload bytes written (the 4-byte header is not counted). The header is
// validated by whead, so an empty p is rejected with errMsgLen.
func (self *connection) Write(p []byte) (n int, err error) {
	self.wlock.Lock()
	defer self.wlock.Unlock()

	if headErr := self.whead(len(p)); headErr != nil {
		return 0, headErr
	}
	return self.rwc.Write(p)
}
// Writer streams exactly size bytes from r as a single frame: a length
// header followed by the payload copied with io.CopyN. n is the number of
// payload bytes copied.
//
// If r yields fewer than size bytes, io.CopyN returns io.EOF and the frame on
// the wire is shorter than its header promised.
func (self *connection) Writer(size int, r io.Reader) (n int64, err error) {
	self.wlock.Lock()
	defer self.wlock.Unlock()

	if err = self.whead(size); err != nil {
		return 0, err
	}
	return io.CopyN(self.rwc, r, int64(size))
}
// RemoteAddr returns the remote network address of the wrapped net.Conn.
func (self *connection) RemoteAddr() net.Addr {
	underlying := self.rwc
	return underlying.RemoteAddr()
}

// LocalAddr returns the local network address of the wrapped net.Conn.
func (self *connection) LocalAddr() net.Addr {
	underlying := self.rwc
	return underlying.LocalAddr()
}
// SetDeadline forwards to the wrapped net.Conn's SetDeadline.
func (self *connection) SetDeadline(t time.Time) error {
	underlying := self.rwc
	return underlying.SetDeadline(t)
}

// SetReadDeadline forwards to the wrapped net.Conn's SetReadDeadline.
func (self *connection) SetReadDeadline(t time.Time) error {
	underlying := self.rwc
	return underlying.SetReadDeadline(t)
}

// SetWriteDeadline forwards to the wrapped net.Conn's SetWriteDeadline.
func (self *connection) SetWriteDeadline(t time.Time) error {
	underlying := self.rwc
	return underlying.SetWriteDeadline(t)
}
// Close closes the underlying net.Conn and returns its error.
//
// The *connection is deliberately not returned to connPool. Callers may still
// hold a reference after Close (or call Close twice), and a reader may still
// hold rlock. Handing the value to a new connection would therefore alias two
// sockets or deadlock the next owner. rlen is also left untouched. Writing it
// here without rlock would race with a concurrent Read.
func (self *connection) Close() (err error) {
	return self.rwc.Close()
}
// limitRead is the io.Reader handed out by connection.Read. It forwards to r
// (an io.LimitReader bounded to one frame's payload) and calls unlock whenever
// r returns a non-nil error — normally io.EOF once the payload is consumed —
// which releases the connection's read lock.
//
// NOTE(review): as wired up by connection.Read, if the caller stops reading
// before r reports an error, unlock never runs and the next Read on the
// connection blocks forever.
type limitRead struct {
	r      io.Reader
	unlock func()
}

// Read implements io.Reader.
func (self limitRead) Read(p []byte) (n int, err error) {
	n, err = self.r.Read(p)
	if err == nil {
		return n, nil
	}
	self.unlock()
	return n, err
}
Test functions:
package connection
import (
"fmt"
"io"
"net"
"os"
"testing"
"time"
)
// Test_conn starts Listener on port 1789 in the background, waits one second
// for it to begin accepting, and then runs the Dial client against it.
// Dial only prints failures, so this test does not fail on them.
func Test_conn(t *testing.T) {
	const addr = ":1789"
	go Listener("tcp", addr)
	// Crude readiness wait: Listener has no way to signal that it is bound.
	time.Sleep(time.Second)
	Dial()
}
// Dial connects to the test server at 127.0.0.1:1789, sends two framed
// "Test" messages, and copies the server's framed reply to stdout.
// Failures are printed rather than returned.
func Dial() {
	conn, err := net.Dial("tcp", "127.0.0.1:1789")
	if err != nil {
		fmt.Println(err)
		return
	}
	c := Newconnection(conn)
	defer c.Close()

	for i := 0; i < 2; i++ {
		if _, err := c.Write([]byte("Test")); err != nil {
			fmt.Println(err)
			return
		}
	}

	r, size, err := c.Read()
	if err != nil {
		fmt.Println(err, size)
		return
	}
	// io.Copy treats io.EOF as normal termination and returns nil for it, so
	// any error here is a real failure.
	if _, err := io.Copy(os.Stdout, r); err != nil {
		fmt.Println(err)
	}
}
// Listener accepts connections on addr and serves each one with handler in
// its own goroutine. It panics if the address cannot be bound and otherwise
// loops forever, so the deferred Close never runs.
func Listener(proto, addr string) {
	lis, err := net.Listen(proto, addr)
	if err != nil {
		panic("Listen port error:" + err.Error())
	}
	defer lis.Close()
	for {
		conn, err := lis.Accept()
		if err != nil {
			// Back off briefly before retrying a failed Accept.
			time.Sleep(10 * time.Millisecond)
			continue
		}
		go handler(conn)
	}
}
// handler serves one client connection for the test. It prints each framed
// message it receives to stdout. About two seconds after the first message it
// signals a helper goroutine, which replies with the contents of tcp_test.go
// as a single frame. handler returns when a read fails, typically because the
// client closed the connection.
func handler(conn net.Conn) {
	c := Newconnection(conn)
	defer c.Close()

	// One slot of buffer so the first signal never blocks; later signals are
	// dropped by the non-blocking send below because the helper replies once.
	msgchan := make(chan struct{}, 1)
	// Closing on return releases the helper if no message ever arrived.
	defer close(msgchan)

	go func(ch <-chan struct{}) {
		if _, ok := <-ch; !ok {
			return
		}
		f, err := os.Open("tcp_test.go")
		if err != nil {
			fmt.Println(err)
			return
		}
		defer f.Close()
		info, err := f.Stat()
		if err != nil {
			fmt.Println(err)
			return
		}
		// The handler's deferred Close owns the connection; closing it here as
		// well would close it twice.
		if _, err := c.Writer(int(info.Size()), f); err != nil {
			fmt.Println(err)
		}
	}(msgchan)

	for {
		r, size, err := c.Read()
		if err != nil {
			fmt.Println(err)
			return
		}
		// io.Copy returns nil at io.EOF, so a short n with a nil error means the
		// frame was truncated.
		n, err := io.Copy(os.Stdout, r)
		if err != nil || n != int64(size) {
			// Message means "failed to read data".
			fmt.Println("读取数据失败:", err)
			return
		}
		time.Sleep(2 * time.Second)
		select {
		case msgchan <- struct{}{}:
		default:
		}
	}
}