diff --git a/snmpserver.go b/snmpserver.go index 2d21860..faa1a62 100644 --- a/snmpserver.go +++ b/snmpserver.go @@ -1,13 +1,19 @@ package GoSNMPServer -import "net" -import "github.com/pkg/errors" -import "reflect" +import ( + "net" + "reflect" + "strings" + "syscall" + + "github.com/pkg/errors" +) type SNMPServer struct { - wconnStream ISnmpServerListener - master MasterAgent - logger ILogger + udpConnStream ISnmpServerListener + tcpListener net.Listener + master MasterAgent + logger ILogger } func NewSNMPServer(master MasterAgent) *SNMPServer { @@ -20,8 +26,41 @@ func NewSNMPServer(master MasterAgent) *SNMPServer { return ret } +func (server *SNMPServer) Listen(l3proto, address string) error { + + if strings.EqualFold(l3proto, "udp") { + return server.ListenUDP(l3proto, address) + } else if strings.EqualFold(l3proto, "tcp") || strings.EqualFold(l3proto, "both") { + err := server.ListenUDP("udp", address) + if err != nil { + return err + } else { + return server.ListenTCP(address) + } + + } else { + return errors.New("Unknown l3proto") + } + +} + +func (server *SNMPServer) ListenTCP(address string) error { + + if server.tcpListener != nil { + return errors.New("Listened TCP") + } + var err error + list, err := net.Listen("tcp", address) + if err != nil { + return errors.Wrap(err, "TCP Listen Error") + } + server.tcpListener = list + server.logger.Infof("ListenTCP: address=%s", address) + return nil +} + func (server *SNMPServer) ListenUDP(l3proto, address string) error { - if server.wconnStream != nil { + if server.udpConnStream != nil { return errors.New("Listened") } i, err := NewUDPListener(l3proto, address) @@ -30,26 +69,35 @@ func (server *SNMPServer) ListenUDP(l3proto, address string) error { } server.logger.Infof("ListenUDP: l3proto=%s, address=%s", l3proto, address) i.SetupLogger(server.logger) - server.wconnStream = i + server.udpConnStream = i return nil } func (server *SNMPServer) Address() net.Addr { - return server.wconnStream.Address() + return server.udpConnStream.Address() } func (server *SNMPServer) Shutdown() { server.logger.Infof("Shutdown server") - if server.wconnStream != nil { - server.wconnStream.Shutdown() + if server.udpConnStream != nil { + server.udpConnStream.Shutdown() + } + + if server.tcpListener != nil { + server.tcpListener.Close() + server.tcpListener = nil } + } func (server *SNMPServer) ServeForever() error { - if server.wconnStream == nil { + if server.udpConnStream == nil { return errors.New("Not Listen") } + if server.tcpListener != nil { + go server.ServeForEverTCP() + } for { err := server.ServeNextRequest() if err != nil { @@ -65,6 +113,54 @@ func (server *SNMPServer) ServeForever() error { } } +func (server *SNMPServer) ServeForEverTCP() error { + + for { + conn, err := server.tcpListener.Accept() + if err != nil { + // Check if the error is due to closing the listener + if opErr, ok := err.(*net.OpError); ok && opErr.Op == "accept" && opErr.Net == "tcp" && opErr.Err == syscall.EINVAL { + server.logger.Errorf("ServeForEverTCP: listener closing accept error %v [type %v]", err, reflect.TypeOf(err)) + break // Exit the loop + } + + server.logger.Errorf("ServeForEverTCP: listener accept error %v [type %v]", err, reflect.TypeOf(err)) + continue // Continue accepting other connections + } + + go server.handleTCPConnection(conn) // Handle each connection in a new goroutine + } + return nil +} + +func (server *SNMPServer) handleTCPConnection(conn net.Conn) (err error) { + server.logger.Infof("tcp connection from %v", conn.RemoteAddr()) + defer conn.Close() + + for { + var msg [4096]byte + counts, err := conn.Read(msg[:]) + if err != nil { + server.logger.Errorf("tcp read error: %v", err) + break + } + server.logger.Infof("tcp request from %v. size=%v", conn.RemoteAddr(), counts) + result, err := server.master.ResponseForBuffer(msg[:counts]) + if err != nil { + server.logger.Errorf("ResponseForBuffer Error: %v. %s result", err, result) + break + } + + _, err = conn.Write(result) + if err != nil { + server.logger.Errorf("ResponseForBuffer Error: %v.", err) + break + + } + } + return nil +} + func (server *SNMPServer) ServeNextRequest() (err error) { defer func() { if err := recover(); err != nil { @@ -78,7 +174,7 @@ func (server *SNMPServer) ServeNextRequest() (err error) { return } }() - bytePDU, replyer, err := server.wconnStream.NextSnmp() + bytePDU, replyer, err := server.udpConnStream.NextSnmp() if err != nil { return err }