From d9c50d677f51240cbae68c551eac562ee48bb443 Mon Sep 17 00:00:00 2001 From: Warmist Date: Sat, 15 Aug 2015 15:09:12 +0300 Subject: [PATCH] A lua interface for csockets in a spirit of luasocket --- plugins/lua/luasocket.lua | 67 ++++++++ plugins/luasocket.cpp | 352 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 419 insertions(+) create mode 100644 plugins/lua/luasocket.lua create mode 100644 plugins/luasocket.cpp diff --git a/plugins/lua/luasocket.lua b/plugins/lua/luasocket.lua new file mode 100644 index 000000000..34cfeed38 --- /dev/null +++ b/plugins/lua/luasocket.lua @@ -0,0 +1,67 @@ +local _ENV = mkmodule('plugins.luasocket') +local _funcs={} +for k,v in pairs(_ENV) do + if type(v)=="function" then + _funcs[k]=v + _ENV[k]=nil + end +end + +local socket=defclass(socket) +socket.ATTRS={ + server_id=-1, + client_id=-1, +} + +function socket:close( ) + if self.client_id==-1 then + _funcs.lua_server_close(self.server_id) + else + _funcs.lua_client_close(self.server_id,self.client_id) + end +end +function socket:setTimeout( sec,msec ) + msec=msec or 0 + _funcs.lua_socket_set_timeout(self.server_id,self.client_id,sec,msec) +end + +local client=defclass(client,socket) +function client:receive( pattern ) + local pattern=pattern or "*l" + local bytes=-1 + if type(pattern)== number then + bytes=pattern + end + local ret=_funcs.lua_client_receive(self.server_id,self.client_id,bytes,pattern,false) + if ret=="" then + return + else + return ret + end +end +function client:send( data ) + _funcs.lua_client_send(self.server_id,self.client_id,data) +end + + +local server=defclass(server,socket) +function server:accept() + local id=_funcs.lua_server_accept(self.server_id,false) + if id~=nil then + return client{server_id=self.server_id,client_id=id} + else + return + end +end + +tcp={} +function tcp:bind( address,port ) + local id=_funcs.lua_socket_bind(address,port) + return server{server_id=id} +end +function tcp:connect( address,port ) + local id=_funcs.lua_socket_connect(address,port) + return client{client_id=id} +end +--TODO garbage collect stuff +return _ENV \ No newline at end of file diff --git a/plugins/luasocket.cpp b/plugins/luasocket.cpp new file mode 100644 index 000000000..e0a41a8b0 --- /dev/null +++ b/plugins/luasocket.cpp @@ -0,0 +1,352 @@ +#include "Core.h" +#include "Console.h" +#include "Export.h" +#include "PluginManager.h" +#include "DataDefs.h" + +#include +#include +#include +#include +#include +#include +#include "MiscUtils.h" +#include "LuaTools.h" +#include "DataFuncs.h" +#include //todo convert errors to lua-errors and co. Then remove this + +using namespace DFHack; +using namespace df::enums; +struct server +{ + CPassiveSocket *socket; + std::map clients; + int last_client_id; + void close(); +}; +std::map servers; +typedef std::map clients_map; +clients_map clients; //free clients, i.e. non-server spawned clients +DFHACK_PLUGIN("luasocket"); + +// The error messages are taken from the clsocket source code +const char * translate_socket_error(CSimpleSocket::CSocketError err) { + switch (err) { + case CSimpleSocket::SocketError: + return "Generic socket error translates to error below."; + case CSimpleSocket::SocketSuccess: + return "No socket error."; + case CSimpleSocket::SocketInvalidSocket: + return "Invalid socket handle."; + case CSimpleSocket::SocketInvalidAddress: + return "Invalid destination address specified."; + case CSimpleSocket::SocketInvalidPort: + return "Invalid destination port specified."; + case CSimpleSocket::SocketConnectionRefused: + return "No server is listening at remote address."; + case CSimpleSocket::SocketTimedout: + return "Timed out while attempting operation."; + case CSimpleSocket::SocketEwouldblock: + return "Operation would block if socket were blocking."; + case CSimpleSocket::SocketNotconnected: + return "Currently not connected."; + case CSimpleSocket::SocketEinprogress: + return "Socket is non-blocking and the connection cannot be completed immediately"; + case CSimpleSocket::SocketInterrupted: + return "Call was interrupted by a signal that was caught before a valid connection arrived."; + case CSimpleSocket::SocketConnectionAborted: + return "The connection has been aborted."; + case CSimpleSocket::SocketProtocolError: + return "Invalid protocol for operation."; + case CSimpleSocket::SocketFirewallError: + return "Firewall rules forbid connection."; + case CSimpleSocket::SocketInvalidSocketBuffer: + return "The receive buffer point outside the process's address space."; + case CSimpleSocket::SocketConnectionReset: + return "Connection was forcibly closed by the remote host."; + case CSimpleSocket::SocketAddressInUse: + return "Address already in use."; + case CSimpleSocket::SocketInvalidPointer: + return "Pointer type supplied as argument is invalid."; + case CSimpleSocket::SocketEunknown: + return "Unknown error please report to mark@carrierlabs.com"; + default: + return "No such CSimpleSocket error"; + } +} +void server::close() +{ + for(auto it=clients.begin();it!=clients.end();it++) + { + CActiveSocket* sock=it->second; + sock->Close(); + delete sock; + } + clients.clear(); + socket->Close(); + delete socket; +} +std::pair get_client(int server_id,int client_id) +{ + std::map* target=&clients; + if(server_id>0) + { + if(servers.count(server_id)==0) + { + throw std::runtime_error("Server with this id does not exist"); + } + server &cur_server=servers[server_id]; + target=&cur_server.clients; + } + + if(target->count(client_id)==0) + { + throw std::runtime_error("Client does with this id not exist"); + } + CActiveSocket *sock=(*target)[client_id]; + return std::make_pair(sock,target); +} +void handle_error(CSimpleSocket::CSocketError err,bool skip_timeout=true) +{ + if(err==CSimpleSocket::SocketSuccess) + return; + if(err==CSimpleSocket::SocketTimedout && skip_timeout) + return; + throw std::runtime_error(translate_socket_error(err)); +} +static int lua_socket_bind(std::string ip,int port) +{ + static int server_id=0; + CPassiveSocket* sock=new CPassiveSocket; + if(!sock->Initialize()) + { + CSimpleSocket::CSocketError err=sock->GetSocketError(); + delete sock; + handle_error(err,false); + } + sock->SetBlocking(); + if(!sock->Listen((uint8_t*)ip.c_str(),port)) + { + handle_error(sock->GetSocketError(),false); + } + server_id++; + server& cur_server=servers[server_id]; + cur_server.socket=sock; + cur_server.last_client_id=0; + return server_id; +} +static int lua_server_accept(int id,bool fail_on_timeout) +{ + if(servers.count(id)==0) + { + throw std::runtime_error("Server not bound"); + } + server &cur_server=servers[id]; + CActiveSocket* sock=cur_server.socket->Accept(); + if(!sock) + { + handle_error(sock->GetSocketError(),!fail_on_timeout); + return 0; + } + else + { + cur_server.last_client_id++; + cur_server.clients[cur_server.last_client_id]=sock; + return cur_server.last_client_id; + } +} +static void lua_client_close(int server_id,int client_id) +{ + auto info=get_client(server_id,client_id); + + CActiveSocket *sock=info.first; + std::map* target=info.second; + + target->erase(client_id); + CSimpleSocket::CSocketError err=CSimpleSocket::SocketSuccess; + if(!sock->Close()) + err=sock->GetSocketError(); + delete sock; + if(err!=CSimpleSocket::SocketSuccess) + { + throw std::runtime_error(translate_socket_error(err)); + } +} +static void lua_server_close(int server_id) +{ + if(servers.count(server_id)==0) + { + throw std::runtime_error("Server with this id does not exist"); + } + server &cur_server=servers[server_id]; + try{ + cur_server.close(); + } + catch(...) + { + servers.erase(server_id); + throw; + } +} +static std::string lua_client_receive(int server_id,int client_id,int bytes,std::string pattern,bool fail_on_timeout) +{ + auto info=get_client(server_id,client_id); + CActiveSocket *sock=info.first; + if(bytes>0) + { + if(sock->Receive(bytes)<=0) + { + throw std::runtime_error(translate_socket_error(sock->GetSocketError())); + } + return std::string((char*)sock->GetData(),bytes); + } + else + { + std::string ret; + if(pattern=="*a") //?? + { + while(true) + { + int received=sock->Receive(1); + if(received<0) + { + handle_error(sock->GetSocketError(),!fail_on_timeout); + return "";//maybe return partial string? + } + else if(received==0) + { + break; + } + ret+=(char)*sock->GetData(); + } + return ret; + } + else if (pattern=="" || pattern=="*l") + { + while(true) + { + + if(sock->Receive(1)<=0) + { + handle_error(sock->GetSocketError(),!fail_on_timeout); + return "";//maybe return partial string? + } + char rec=(char)*sock->GetData(); + if(rec=='\n') + break; + ret+=rec; + } + return ret; + } + else + { + throw std::runtime_error("Unsupported receive pattern"); + } + } +} +static void lua_client_send(int server_id,int client_id,std::string data) +{ + if(data.size()==0) + return; + std::map* target=&clients; + if(server_id>0) + { + if(servers.count(server_id)==0) + { + throw std::runtime_error("Server with this id does not exist"); + } + server &cur_server=servers[server_id]; + target=&cur_server.clients; + } + + if(target->count(client_id)==0) + { + throw std::runtime_error("Client does with this id not exist"); + } + CActiveSocket *sock=(*target)[client_id]; + if(sock->Send((const uint8_t*)data.c_str(),data.size())!=data.size()) + { + throw std::runtime_error(translate_socket_error(sock->GetSocketError())); + } +} +static int lua_socket_connect(std::string ip,int port) +{ + static int last_client_id=0; + CActiveSocket* sock=new CActiveSocket; + if(!sock->Initialize()) + { + CSimpleSocket::CSocketError err=sock->GetSocketError(); + delete sock; + throw std::runtime_error(translate_socket_error(err)); + } + if(!sock->Open((const uint8_t*)ip.c_str(),port)) + { + CSimpleSocket::CSocketError err=sock->GetSocketError(); + delete sock; + throw std::runtime_error(translate_socket_error(err)); + } + last_client_id++; + clients[last_client_id]=sock; + return last_client_id; +} +static void lua_socket_set_timeout(int server_id,int client_id,int32_t sec,int32_t msec) +{ + std::map* target=&clients; + if(server_id>0) + { + if(servers.count(server_id)==0) + { + throw std::runtime_error("Server with this id does not exist"); + } + server &cur_server=servers[server_id]; + if(client_id==-1) + { + cur_server.socket->SetConnectTimeout(sec,msec); + cur_server.socket->SetReceiveTimeout(sec,msec); + cur_server.socket->SetSendTimeout(sec,msec); + return; + } + target=&cur_server.clients; + } + + if(target->count(client_id)==0) + { + throw std::runtime_error("Client does with this id not exist"); + } + CActiveSocket *sock=(*target)[client_id]; + sock->SetConnectTimeout(sec,msec); + sock->SetReceiveTimeout(sec,msec); + sock->SetSendTimeout(sec,msec); +} +DFHACK_PLUGIN_LUA_FUNCTIONS { + DFHACK_LUA_FUNCTION(lua_socket_bind), //spawn a server + DFHACK_LUA_FUNCTION(lua_socket_connect),//spawn a client (i.e. connection) + DFHACK_LUA_FUNCTION(lua_socket_set_timeout), + DFHACK_LUA_FUNCTION(lua_server_accept), + DFHACK_LUA_FUNCTION(lua_server_close), + DFHACK_LUA_FUNCTION(lua_client_close), + DFHACK_LUA_FUNCTION(lua_client_send), + DFHACK_LUA_FUNCTION(lua_client_receive), + DFHACK_LUA_END +}; +DFhackCExport command_result plugin_init ( color_ostream &out, std::vector &commands) +{ + + return CR_OK; +} +DFhackCExport command_result plugin_shutdown ( color_ostream &out ) +{ + for(auto it=clients.begin();it!=clients.end();it++) + { + CActiveSocket* sock=it->second; + sock->Close(); + delete sock; + } + clients.clear(); + for(auto it=servers.begin();it!=servers.end();it++) + { + it->second.close(); + } + servers.clear(); + return CR_OK; +} \ No newline at end of file