Simplify listener code

This commit is contained in:
NXShock 2025-06-03 11:40:52 +05:00
parent c4c9bff7c6
commit e4f9a1a130

View file

@ -31,7 +31,7 @@ func (h HostMapping) Add(host, outputUrlStr string) error {
case "file": case "file":
server := http.Server{Handler: http.FileServer(http.Dir(outputUrl.Path))} server := http.Server{Handler: http.FileServer(http.Dir(outputUrl.Path))}
go server.Serve(pd.listener) go server.Serve(pd.listener)
case "tcp": case "tcp", "unix":
go func(pd ProxyDirection) { go func(pd ProxyDirection) {
for { for {
conn, err := pd.listener.Accept() conn, err := pd.listener.Accept()
@ -39,18 +39,7 @@ func (h HostMapping) Add(host, outputUrlStr string) error {
slog.Debug(err.Error()) slog.Debug(err.Error())
continue continue
} }
go handleTcp(conn.(*tls.Conn), pd.Output) go handleListener(conn.(*tls.Conn), pd.Output)
}
}(pd)
case "unix":
go func(pd ProxyDirection) {
for {
conn, err := pd.listener.Accept()
if err != nil {
slog.Debug(err.Error())
continue
}
go handleUnix(conn.(*tls.Conn), pd.Output)
} }
}(pd) }(pd)
default: default:
@ -79,41 +68,12 @@ func handleTlsConn(conn *tls.Conn, hosts HostMapping) error {
return nil return nil
} }
func handleTcp(conn *tls.Conn, outputUrl *url.URL) { func handleListener(conn *tls.Conn, outputUrl *url.URL) {
slog.Debug(fmt.Sprintf("%s -> %s", conn.RemoteAddr(), outputUrl.Host))
c, err := net.Dial(outputUrl.Scheme, outputUrl.Host)
if err != nil {
writeError(conn, err)
conn.Close()
return
}
defer c.Close()
wg := new(sync.WaitGroup)
wg.Add(2)
go func() {
defer wg.Done()
_, _ = io.Copy(conn, c)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(c, conn)
}()
wg.Wait()
}
func handleUnix(conn *tls.Conn, outputUrl *url.URL) {
slog.Debug(fmt.Sprintf("%s -> %s", conn.RemoteAddr(), outputUrl.Host+outputUrl.Path)) slog.Debug(fmt.Sprintf("%s -> %s", conn.RemoteAddr(), outputUrl.Host+outputUrl.Path))
c, err := net.Dial(outputUrl.Scheme, outputUrl.Host+outputUrl.Path) c, err := net.Dial(outputUrl.Scheme, outputUrl.Host+outputUrl.Path)
if err != nil { if err != nil {
writeError(conn, err) fmt.Fprintf(conn, "HTTP/1.1 500 Internal Server Error\r\nConnection: Close\r\nContent-Type: text/plain\r\n\r\n%s", err)
conn.Close() conn.Close()
return return
} }
@ -136,7 +96,3 @@ func handleUnix(conn *tls.Conn, outputUrl *url.URL) {
wg.Wait() wg.Wait()
} }
func writeError(w io.Writer, err error) {
fmt.Fprintf(w, "HTTP/1.1 500 Internal Server Error\r\nConnection: Close\r\nContent-Type: text/plain\r\n\r\n%s", err)
}