Bind to user-specified address (BindAddress6)
[mgsmtp.git] / Network.pas
index d5b458da2d5919570e2805ae5416b63eac408752..a28ff634077b64e0a77e8c36be148ead9d515017 100644 (file)
 unit Network;
 
 interface
-uses Classes, Sockets, SocketUtils, DNSResolve, NetRFC, Common;
+uses Classes, Sockets, SocketUtils, ctypes, DNSResolve, NetRFC, Common;
 
 const
 
+   { Address families: }
+   { These are here so users of this unit don't necessarily have to
+     use Sockets as well. }
+   AF_UNSPEC            =  Sockets.AF_UNSPEC;
+   AF_INET              =  Sockets.AF_INET;
+   AF_INET6             =  Sockets.AF_INET6;
+
    { Connection feature requests: }
    NET_TCP_BASIC        =  0;
    NET_TCP_RFCSUPPORT   =  1;
@@ -54,15 +61,21 @@ type
    TTCPConnection = class
       constructor Create; overload;
       constructor Create(const HostName: string; Port: word); overload;
-      constructor Create(Socket: socket; const Addr: TSockAddr); overload;
+      constructor Create(Socket: socket); overload;
       destructor Destroy; override;
    private
       FConnected: boolean;
       FSocket: socket;
       FHostIP: TIPNamePair;
       FSockTimeOut: DWord;
-      SockAddr: TSockAddr;
+      SrcSockAddr: TSockAddr;
+      SrcSockAddr6: TSockAddr6;
+      DstSockAddr: TSockAddr6;
+   protected
+      function IsNullAddress(SockAddr: PSockAddr): boolean;
+      function BindSrcAddr(Socket: socket; Family: word): cint;
    public
+      function SetBindAddress(Family: word; const HostName: string): boolean;
       function Connect(const HostName: string; Port: word): boolean;
       procedure Disconnect;
       procedure ReverseDNSLookup;
@@ -99,14 +112,15 @@ type
    end;
 
    TTCPListener = class(TThread)
-      constructor Create(const Address: string; Port: word; FeatureRequest: word);
+      constructor Create(const Address: string; Port: word; Family: word; FeatureRequest: word);
       {destructor Destroy; override;}
    private
       FFeatureRequest: word;
+      FFamily: word;
       FListenAddress: string;
       FListenPort: word;
       FListenSocket: socket;
-      SockAddr: TSockAddr;
+      SockAddr: TSockAddr6;
    protected
       procedure HandleClient(Connection: TTCPConnection); virtual; abstract;
       procedure Execute; override;
@@ -128,6 +142,11 @@ begin
    FConnected:= false;
    FSocket:= -1;
    FSockTimeOut:= DEF_SOCK_TIMEOUT;
+   FillChar(SrcSockAddr, SizeOf(SrcSockAddr), 0);
+   FillChar(SrcSockAddr6, SizeOf(SrcSockAddr6), 0);
+   FillChar(DstSockAddr, SizeOf(DstSockAddr), 0);
+   SrcSockAddr.sin_family:= AF_INET;
+   SrcSockAddr6.sin6_family:= AF_INET6;
 end;
 
 constructor TTCPConnection.Create(const HostName: string; Port: word);
@@ -137,13 +156,17 @@ begin
    Connect(HostName, Port);
 end;
 
-constructor TTCPConnection.Create(Socket: socket; const Addr: TSockAddr);
+constructor TTCPConnection.Create(Socket: socket);
 { Use an already connected socket. }
+var ssocklen, dsocklen: TSockLen;
 begin
    inherited Create;
    FSocket:= Socket;
-   SockAddr:= Addr;
-   FHostIP:= TIPNamePair.Create('', NetAddrToStr(Addr.sin_addr));
+   ssocklen:= SizeOf(SrcSockAddr);
+   dsocklen:= SizeOf(DstSockAddr);
+   fpgetsockname(FSocket, @SrcSockAddr, @ssocklen);
+   fpgetpeername(FSocket, @DstSockAddr, @dsocklen);
+   FHostIP:= TIPNamePair.Create('', IPToStr(@DstSockAddr));
    FConnected:= true;
 end;
 
@@ -164,8 +187,9 @@ begin
 end;
 
 
-constructor TTCPListener.Create(const Address: string; Port: word; FeatureRequest: word);
+constructor TTCPListener.Create(const Address: string; Port: word; Family: word; FeatureRequest: word);
 begin
+   FFamily:= Family;
    FListenAddress:= Address;
    FListenPort:= Port;
    FFeatureRequest:= FeatureRequest;
@@ -175,30 +199,91 @@ begin
 end;
 
 
+function TTCPConnection.IsNullAddress(SockAddr: PSockAddr): boolean;
+begin
+   if SockAddr^.sin_family = AF_INET then
+      Result:= SockAddr^.sin_addr.s_addr = 0
+   else if SockAddr^.sin_family = AF_INET6 then
+      Result:= (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[0] = 0)
+         and (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[1] = 0)
+         and (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[2] = 0)
+         and (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[3] = 0)
+   else
+      Result:= true;
+end;
+
+function TTCPConnection.BindSrcAddr(Socket: socket; Family: word): cint;
+var SockAddr: PSockAddr; addrlen: size_t;
+begin
+   case Family of
+      AF_INET:
+         begin
+            SockAddr:= @SrcSockAddr;
+            addrlen:= SizeOf(SrcSockAddr);
+         end;
+      AF_INET6:
+         begin
+            SockAddr:= @SrcSockAddr6;
+            addrlen:= SizeOf(SrcSockAddr6);
+         end;
+   end;
+
+   if not IsNullAddress(SockAddr) then
+      Result:= fpBind(Socket, SockAddr, addrlen)
+   else
+      Result:= 0;
+end;
+
+function TTCPConnection.SetBindAddress(Family: word; const HostName: string): boolean;
+var GAIResult: TGAIResult; SockAddr: PSockAddr;
+begin
+   GAIResult:= ResolveHost(HostName, Family);
+   if GAIResult.GAIError = 0 then begin
+      case GAIResult.AddrInfo^.ai_family of
+         AF_INET:    SockAddr:= @SrcSockAddr;
+         AF_INET6:   SockAddr:= @SrcSockAddr6;
+      end;
+      Move(GAIResult.AddrInfo^.ai_addr^, SockAddr^, GAIResult.AddrInfo^.ai_addrlen);
+      FreeHost(GAIResult);
+      Result:= true;
+   end
+   else
+      Result:= false;
+end;
+
 function TTCPConnection.Connect(const HostName: string; Port: word): boolean;
 { Resolves the given hostname, and tries to connect it on the given port. }
 var GAIResult: TGAIResult;
 begin
-   FSocket:= fpSocket(af_inet, sock_stream, 0);
-   if (FSocket <> -1) then begin
-      GAIResult:= ResolveHost(HostName);
-         if GAIResult.GAIError = 0 then begin
-            SockAddr:= GAIResult.AddrInfo^.ai_addr^;
-                SockAddr.sin_port:= htons(Port);
-
-                if SockAddr.sin_addr.s_addr <> 0 then
-                   { Try to initiate connection. }
-                   FConnected:= fpConnect(FSocket, @SockAddr, SizeOf(SockAddr)) <> -1;
-
-                if FConnected then begin
-                   FHostIP:= TIPNamePair.Create(HostName, NetAddrToStr(SockAddr.sin_addr));
-                   SetSockTimeOut(FSockTimeOut);
-                end
-                else
-                   CloseSocket(FSocket);
+   GAIResult:= ResolveHost(HostName, AF_UNSPEC);
+   if GAIResult.GAIError = 0 then begin
+      Move(GAIResult.AddrInfo^.ai_addr^, DstSockAddr, GAIResult.AddrInfo^.ai_addrlen);
+      DstSockAddr.sin6_port:= htons(Port);
+
+      { Create socket. }
+      FSocket:= fpSocket(GAIResult.AddrInfo^.ai_family, SOCK_STREAM, 0);
+
+      if (FSocket <> -1) then begin
+
+         if BindSrcAddr(FSocket, GAIResult.AddrInfo^.ai_family) = 0 then begin
+
+            { Try to initiate connection. }
+            FConnected:= fpConnect(FSocket, @DstSockAddr, GAIResult.AddrInfo^.ai_addrlen) <> -1;
+
+            if FConnected then begin
+               FHostIP:= TIPNamePair.Create(HostName, IPToStr(@DstSockAddr));
+               SetSockTimeOut(FSockTimeOut);
+            end
+            else
+               CloseSocket(FSocket);
+
+         end
+         else
+            CloseSocket(FSocket);
 
-         FreeHost(GAIResult);
       end;
+
+      FreeHost(GAIResult);
    end;
    Result:= FConnected;
 end;
@@ -217,7 +302,7 @@ procedure TTCPConnection.ReverseDNSLookup;
 var NHostIP: TIPNamePair;
 begin
    if FConnected then begin
-      NHostIP:= TIPNamePair.Create(ResolveIP(@SockAddr), FHostIP.IP);
+      NHostIP:= TIPNamePair.Create(ResolveIP(PSockAddr(@DstSockAddr)), FHostIP.IP);
       FHostIP.Free;
       FHostIP:= NHostIP;
    end;
@@ -227,13 +312,13 @@ function TTCPConnection.VerifyFCrDNS: boolean;
 var GAIResult: TGAIResult; ai: PAddrInfo;
 begin
    Result:= false;
-   GAIResult:= ResolveHost(HostIP.Name);
+   GAIResult:= ResolveHost(HostIP.Name, AF_UNSPEC);
    if GAIResult.GAIError = 0 then begin
       ai:= GAIResult.AddrInfo;
-         { One of the addresses must match. }
-         while (ai <> nil) and not Result do begin
-            Result:= NetAddrToStr(ai^.ai_addr^.sin_addr) = HostIP.IP;
-                ai:= ai^.ai_next;
+      { One of the addresses must match. }
+      while (ai <> nil) and not Result do begin
+         Result:= IPToStr(ai^.ai_addr) = HostIP.IP;
+         ai:= ai^.ai_next;
       end;
    end;
 end;
@@ -303,14 +388,14 @@ end;
 function TTCPListener.StartListen: boolean;
 var GAIResult: TGAIResult;
 begin
-   FListenSocket:= fpSocket(af_inet, sock_stream, 0);
+   FListenSocket:= fpSocket(FFamily, SOCK_STREAM, 0);
    if FListenSocket <> -1 then begin
-      GAIResult:= ResolveHost(FListenAddress);
-         if GAIResult.GAIError = 0 then begin
-            SockAddr:= GAIResult.AddrInfo^.ai_addr^;
-                SockAddr.sin_port:= htons(FListenPort);
+      GAIResult:= ResolveHost(FListenAddress, FFamily);
+      if GAIResult.GAIError = 0 then begin
+         Move(GAIResult.AddrInfo^.ai_addr^, SockAddr, GAIResult.AddrInfo^.ai_addrlen);
+         SockAddr.sin6_port:= htons(FListenPort);
 
-         if fpBind(FListenSocket, @SockAddr, SizeOf(SockAddr)) <> -1 then begin
+         if fpBind(FListenSocket, @SockAddr, GAIResult.AddrInfo^.ai_addrlen) <> -1 then begin
             { It seems the maximum connection value isn't enforced by the
               Free Pascal library, so this 512 is a constant, dummy value. }
             if fpListen(FListenSocket, 512) <> -1 then begin
@@ -350,9 +435,9 @@ begin
            connection. }
          case FFeatureRequest of
          NET_TCP_BASIC:
-            TCPConnection:= TTCPConnection.Create(ClientSocket, SockAddr);
+            TCPConnection:= TTCPConnection.Create(ClientSocket);
          NET_TCP_RFCSUPPORT:
-            TCPConnection:= TTCPRFCConnection.Create(ClientSocket, SockAddr);
+            TCPConnection:= TTCPRFCConnection.Create(ClientSocket);
          end;
 
          { Then start a new thread with the connection handler. }