Bind to user-specified address (BindAddress6)
[mgsmtp.git] / Network.pas
index 10a95f991d3898060489ecfdc31d252c242b81d7..a28ff634077b64e0a77e8c36be148ead9d515017 100644 (file)
 unit Network;
 
 interface
-uses Classes, Sockets, SocketUtils, DNSResolve, NetRFC, Common;
+uses Classes, Sockets, SocketUtils, ctypes, DNSResolve, NetRFC, Common;
 
 const
 
+   { Address families: }
+   { These are here so users of this unit don't necessarily have to
+     use Sockets as well. }
+   AF_UNSPEC            =  Sockets.AF_UNSPEC;
+   AF_INET              =  Sockets.AF_INET;
+   AF_INET6             =  Sockets.AF_INET6;
+
    { Connection feature requests: }
    NET_TCP_BASIC        =  0;
    NET_TCP_RFCSUPPORT   =  1;
@@ -54,15 +61,21 @@ type
    TTCPConnection = class
       constructor Create; overload;
       constructor Create(const HostName: string; Port: word); overload;
-      constructor Create(Socket: socket; const Addr: TSockAddr); overload;
+      constructor Create(Socket: socket); overload;
       destructor Destroy; override;
    private
       FConnected: boolean;
       FSocket: socket;
       FHostIP: TIPNamePair;
       FSockTimeOut: DWord;
-      sAddr: TSockAddr;
+      SrcSockAddr: TSockAddr;
+      SrcSockAddr6: TSockAddr6;
+      DstSockAddr: TSockAddr6;
+   protected
+      function IsNullAddress(SockAddr: PSockAddr): boolean;
+      function BindSrcAddr(Socket: socket; Family: word): cint;
    public
+      function SetBindAddress(Family: word; const HostName: string): boolean;
       function Connect(const HostName: string; Port: word): boolean;
       procedure Disconnect;
       procedure ReverseDNSLookup;
@@ -99,14 +112,15 @@ type
    end;
 
    TTCPListener = class(TThread)
-      constructor Create(const Address: string; Port: word; FeatureRequest: word);
+      constructor Create(const Address: string; Port: word; Family: word; FeatureRequest: word);
       {destructor Destroy; override;}
    private
       FFeatureRequest: word;
+      FFamily: word;
       FListenAddress: string;
       FListenPort: word;
       FListenSocket: socket;
-      sAddr: TSockAddr;
+      SockAddr: TSockAddr6;
    protected
       procedure HandleClient(Connection: TTCPConnection); virtual; abstract;
       procedure Execute; override;
@@ -128,22 +142,31 @@ begin
    FConnected:= false;
    FSocket:= -1;
    FSockTimeOut:= DEF_SOCK_TIMEOUT;
+   FillChar(SrcSockAddr, SizeOf(SrcSockAddr), 0);
+   FillChar(SrcSockAddr6, SizeOf(SrcSockAddr6), 0);
+   FillChar(DstSockAddr, SizeOf(DstSockAddr), 0);
+   SrcSockAddr.sin_family:= AF_INET;
+   SrcSockAddr6.sin6_family:= AF_INET6;
 end;
 
 constructor TTCPConnection.Create(const HostName: string; Port: word);
 { Connect to the given port on the given hostname. }
 begin
-   inherited Create;
+   Create;
    Connect(HostName, Port);
 end;
 
-constructor TTCPConnection.Create(Socket: socket; const Addr: TSockAddr);
+constructor TTCPConnection.Create(Socket: socket);
 { Use an already connected socket. }
+var ssocklen, dsocklen: TSockLen;
 begin
    inherited Create;
    FSocket:= Socket;
-   sAddr:= Addr;
-   FHostIP:= TIPNamePair.Create('', NetAddrToStr(Addr.sin_addr));
+   ssocklen:= SizeOf(SrcSockAddr);
+   dsocklen:= SizeOf(DstSockAddr);
+   fpgetsockname(FSocket, @SrcSockAddr, @ssocklen);
+   fpgetpeername(FSocket, @DstSockAddr, @dsocklen);
+   FHostIP:= TIPNamePair.Create('', IPToStr(@DstSockAddr));
    FConnected:= true;
 end;
 
@@ -164,38 +187,103 @@ begin
 end;
 
 
-constructor TTCPListener.Create(const Address: string; Port: word; FeatureRequest: word);
+constructor TTCPListener.Create(const Address: string; Port: word; Family: word; FeatureRequest: word);
 begin
+   FFamily:= Family;
    FListenAddress:= Address;
    FListenPort:= Port;
    FFeatureRequest:= FeatureRequest;
    FreeOnTerminate:= false;
+   FillChar(SockAddr, SizeOf(SockAddr), 0);
    inherited Create(true);
 end;
 
 
+function TTCPConnection.IsNullAddress(SockAddr: PSockAddr): boolean;
+begin
+   if SockAddr^.sin_family = AF_INET then
+      Result:= SockAddr^.sin_addr.s_addr = 0
+   else if SockAddr^.sin_family = AF_INET6 then
+      Result:= (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[0] = 0)
+         and (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[1] = 0)
+         and (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[2] = 0)
+         and (PSockAddr6(SockAddr)^.sin6_addr.u6_addr32[3] = 0)
+   else
+      Result:= true;
+end;
+
+function TTCPConnection.BindSrcAddr(Socket: socket; Family: word): cint;
+var SockAddr: PSockAddr; addrlen: size_t;
+begin
+   case Family of
+      AF_INET:
+         begin
+            SockAddr:= @SrcSockAddr;
+            addrlen:= SizeOf(SrcSockAddr);
+         end;
+      AF_INET6:
+         begin
+            SockAddr:= @SrcSockAddr6;
+            addrlen:= SizeOf(SrcSockAddr6);
+         end;
+   end;
+
+   if not IsNullAddress(SockAddr) then
+      Result:= fpBind(Socket, SockAddr, addrlen)
+   else
+      Result:= 0;
+end;
+
+function TTCPConnection.SetBindAddress(Family: word; const HostName: string): boolean;
+var GAIResult: TGAIResult; SockAddr: PSockAddr;
+begin
+   GAIResult:= ResolveHost(HostName, Family);
+   if GAIResult.GAIError = 0 then begin
+      case GAIResult.AddrInfo^.ai_family of
+         AF_INET:    SockAddr:= @SrcSockAddr;
+         AF_INET6:   SockAddr:= @SrcSockAddr6;
+      end;
+      Move(GAIResult.AddrInfo^.ai_addr^, SockAddr^, GAIResult.AddrInfo^.ai_addrlen);
+      FreeHost(GAIResult);
+      Result:= true;
+   end
+   else
+      Result:= false;
+end;
+
 function TTCPConnection.Connect(const HostName: string; Port: word): boolean;
 { Resolves the given hostname, and tries to connect it on the given port. }
+var GAIResult: TGAIResult;
 begin
-   FSocket:= fpSocket(af_inet, sock_stream, 0);
-   if (FSocket <> -1) then begin
-      with sAddr do begin
-         sin_family:= af_inet;
-         sin_port:= htons(Port);
-         { Resolve hostname to IP address. }
-         sin_addr:= ResolveHost(HostName);
-      end;
+   GAIResult:= ResolveHost(HostName, AF_UNSPEC);
+   if GAIResult.GAIError = 0 then begin
+      Move(GAIResult.AddrInfo^.ai_addr^, DstSockAddr, GAIResult.AddrInfo^.ai_addrlen);
+      DstSockAddr.sin6_port:= htons(Port);
 
-      if sAddr.sin_addr.s_addr <> 0 then
-         { Try to initiate connection. }
-         FConnected:= fpConnect(FSocket, @sAddr, SizeOf(sAddr)) <> -1;
+      { Create socket. }
+      FSocket:= fpSocket(GAIResult.AddrInfo^.ai_family, SOCK_STREAM, 0);
 
-      if FConnected then begin
-         FHostIP:= TIPNamePair.Create(HostName, NetAddrToStr(sAddr.sin_addr));
-         SetSockTimeOut(FSockTimeOut);
-      end
-      else
-         CloseSocket(FSocket);
+      if (FSocket <> -1) then begin
+
+         if BindSrcAddr(FSocket, GAIResult.AddrInfo^.ai_family) = 0 then begin
+
+            { Try to initiate connection. }
+            FConnected:= fpConnect(FSocket, @DstSockAddr, GAIResult.AddrInfo^.ai_addrlen) <> -1;
+
+            if FConnected then begin
+               FHostIP:= TIPNamePair.Create(HostName, IPToStr(@DstSockAddr));
+               SetSockTimeOut(FSockTimeOut);
+            end
+            else
+               CloseSocket(FSocket);
+
+         end
+         else
+            CloseSocket(FSocket);
+
+      end;
+
+      FreeHost(GAIResult);
    end;
    Result:= FConnected;
 end;
@@ -214,15 +302,25 @@ procedure TTCPConnection.ReverseDNSLookup;
 var NHostIP: TIPNamePair;
 begin
    if FConnected then begin
-      NHostIP:= TIPNamePair.Create(ResolveIP(sAddr.sin_addr), FHostIP.IP);
+      NHostIP:= TIPNamePair.Create(ResolveIP(PSockAddr(@DstSockAddr)), FHostIP.IP);
       FHostIP.Free;
       FHostIP:= NHostIP;
    end;
 end;
 
 function TTCPConnection.VerifyFCrDNS: boolean;
+var GAIResult: TGAIResult; ai: PAddrInfo;
 begin
-   Result:= NetAddrToStr(ResolveHost(HostIP.Name)) = HostIP.IP;
+   Result:= false;
+   GAIResult:= ResolveHost(HostIP.Name, AF_UNSPEC);
+   if GAIResult.GAIError = 0 then begin
+      ai:= GAIResult.AddrInfo;
+      { One of the addresses must match. }
+      while (ai <> nil) and not Result do begin
+         Result:= IPToStr(ai^.ai_addr) = HostIP.IP;
+         ai:= ai^.ai_next;
+      end;
+   end;
 end;
 
 procedure TTCPConnection.SetSockTimeOut(TimeOut: DWord);
@@ -288,24 +386,29 @@ end;
 
 
 function TTCPListener.StartListen: boolean;
+var GAIResult: TGAIResult;
 begin
-   FListenSocket:= fpSocket(af_inet, sock_stream, 0);
+   FListenSocket:= fpSocket(FFamily, SOCK_STREAM, 0);
    if FListenSocket <> -1 then begin
-      with sAddr do begin
-         sin_family:= af_inet;
-         sin_port:= htons(FListenPort);
-         sin_addr:= ResolveHost(FListenAddress);
-      end;
-      if fpBind(FListenSocket, @sAddr, sizeof(sAddr)) <> -1 then begin
-         { It seems the maximum connection value isn't enforced by the
-           Free Pascal library, so this 512 is a constant, dummy value. }
-         if fpListen(FListenSocket, 512) <> -1 then begin
-            Result:= true;
-            Start;
+      GAIResult:= ResolveHost(FListenAddress, FFamily);
+      if GAIResult.GAIError = 0 then begin
+         Move(GAIResult.AddrInfo^.ai_addr^, SockAddr, GAIResult.AddrInfo^.ai_addrlen);
+         SockAddr.sin6_port:= htons(FListenPort);
+
+         if fpBind(FListenSocket, @SockAddr, GAIResult.AddrInfo^.ai_addrlen) <> -1 then begin
+            { It seems the maximum connection value isn't enforced by the
+              Free Pascal library, so this 512 is a constant, dummy value. }
+            if fpListen(FListenSocket, 512) <> -1 then begin
+               Result:= true;
+               Start;
+            end
+            else Result:= false;
          end
          else Result:= false;
+
+         FreeHost(GAIResult);
       end
-      else Result:= false;
+         else Result:= false;
    end
    else Result:= false;
 end;
@@ -323,8 +426,8 @@ begin
    { Now, accept connections. }
    AcceptFailCount:= 0;
    while not Terminated do begin
-      Len:= SizeOf(sAddr);
-      ClientSocket:= fpAccept(FListenSocket, @sAddr, @Len);
+      Len:= SizeOf(SockAddr);
+      ClientSocket:= fpAccept(FListenSocket, @SockAddr, @Len);
       if ClientSocket <> -1 then begin
          AcceptFailCount:= 0;
 
@@ -332,9 +435,9 @@ begin
            connection. }
          case FFeatureRequest of
          NET_TCP_BASIC:
-            TCPConnection:= TTCPConnection.Create(ClientSocket, sAddr);
+            TCPConnection:= TTCPConnection.Create(ClientSocket);
          NET_TCP_RFCSUPPORT:
-            TCPConnection:= TTCPRFCConnection.Create(ClientSocket, sAddr);
+            TCPConnection:= TTCPRFCConnection.Create(ClientSocket);
          end;
 
          { Then start a new thread with the connection handler. }