/* NetHalt - NetHalt client service * Copyright (C) 2009 Daniel Collins * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * * Neither the name of the author nor the names of its contributors may * be used to endorse or promote products derived from this software * without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #define TIMEOUT 30 #include #include #include #include #include #include #include #include extern "C" { #include "lib.h" #include "ntservice.h" char const *svc_name = "nhclient"; char const *svc_log = "NetHalt Client"; } struct sdtime { int days; int hour; int min; int sec; struct sdtime *next; }; struct config { int use_server; std::string server_name; int server_port; int server_refresh; int warning; int abort; int delay; std::vector sdtimes; }; struct packet { int64_t sdtime; int32_t warning; int32_t abort; int32_t delay; } __attribute__((__packed__)); static HKEY regkey = NULL; static struct config config; static int listener = -1; static fd_set read_fds; static std::map clients; static time_t sdtime = 0; static std::string reg_get_string(char const *name); static DWORD reg_get_dword(char const *name); static void reg_set_string(char const *name, char const *value); static void reg_set_dword(char const *name, DWORD value); static void close_client(std::map::iterator client); static void add_sdtime(std::vector *sdtimes, char const *str); static time_t calc_sdtime(time_t now); static void update_client(int sockfd); static void update_all(void); static void load_sconfig(void); static std::string dl_config(void); void svc_init(void) { HANDLE token; TOKEN_PRIVILEGES tkp; if(!OpenProcessToken(GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &token)) { die("Failed to open access token: %s", w32_error(GetLastError())); } LookupPrivilegeValue(NULL, SE_SHUTDOWN_NAME, &tkp.Privileges[0].Luid); tkp.PrivilegeCount = 1; tkp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; if(!AdjustTokenPrivileges(token, FALSE, &tkp, 0, (PTOKEN_PRIVILEGES)NULL, 0)) { die("Failed to obtain shutdown privilege: %s", w32_error(GetLastError())); } DWORD errnum = RegOpenKeyEx( HKEY_LOCAL_MACHINE, "SOFTWARE\\NetHalt", 0, KEY_QUERY_VALUE | KEY_SET_VALUE, ®key ); if(errnum != ERROR_SUCCESS) { die("Failed to open registry key: %s", w32_error(errnum)); } config.use_server = reg_get_dword("use_server"); config.server_name = reg_get_string("server_name"); config.server_port = reg_get_dword("server_port"); config.server_refresh = reg_get_dword("server_refresh"); config.warning = reg_get_dword("warning"); config.abort = reg_get_dword("abort"); config.delay = reg_get_dword("delay"); std::string sdtimes = reg_get_string("sdtimes"); char const *sptr = sdtimes.c_str(); while(*sptr) { int i = strspn(sptr, "1234567890:;"); std::string s(sptr, i); sptr += i; sptr += strspn(sptr, ","); add_sdtime(&(config.sdtimes), s.c_str()); } WSADATA wsadata; errnum = WSAStartup(MAKEWORD(2,2), &wsadata); if(errnum != ERROR_SUCCESS) { die("Winsock initialization failed: %s", w32_error(errnum)); } listener = socket(AF_INET, SOCK_STREAM, 0); if(listener == -1) { die("Failed to create IPC socket: %s", w32_error(WSAGetLastError())); } struct sockaddr_in addr; addr.sin_family = AF_INET; addr.sin_port = 0; addr.sin_addr.s_addr = inet_addr("127.0.0.1"); if(bind(listener, (struct sockaddr*)&addr, sizeof(addr)) == -1) { die("Failed to bind IPC socket: %s", w32_error(WSAGetLastError())); } if(listen(listener, 16) == -1) { die("Failed to listen on IPC socket: %s", w32_error(WSAGetLastError())); } int size = sizeof(addr); getsockname(listener, (struct sockaddr*)&addr, &size); reg_set_dword("ipc_port", ntohs(addr.sin_port)); FD_ZERO(&read_fds); FD_SET(listener, &read_fds); } void svc_main(void) { fd_set select_fds; struct timeval timeout; int sret, i; /* TODO: Split this into another thread */ if(config.use_server) { load_sconfig(); } sdtime = calc_sdtime(time(NULL)); while(svc_status.dwCurrentState == SERVICE_RUNNING) { select_fds = read_fds; timeout.tv_sec = 1; timeout.tv_usec = 0; sret = select(0, &select_fds, NULL, NULL, &timeout); if(sret == -1) { die( "Failed to wait for socket events: %s", w32_error(WSAGetLastError()) ); } if(FD_ISSET(listener, &select_fds)) { int fd = accept(listener, NULL, NULL); if(fd == -1) { LogError("Failed to accept connection: %s", w32_error(WSAGetLastError())); }else{ struct linger lopt; lopt.l_onoff = 0; setsockopt(fd, SOL_SOCKET, SO_LINGER, (const char*)&lopt, sizeof(lopt)); clients.insert(std::make_pair(fd, time(NULL))); if(clients.size() == FD_SETSIZE) { FD_CLR(listener, &read_fds); } FD_SET(fd, &read_fds); update_client(fd); } } std::map::iterator client = clients.begin(); std::map::iterator end = clients.end(), next = client; time_t now = time(NULL); while(next != end) { client = next; next++; if(now > client->second + TIMEOUT) { close_client(client); continue; } if(!FD_ISSET(client->first, &select_fds)) { continue; } client->second = now; char cmd; i = recv(client->first, &cmd, 1, 0); if(i == -1) { LogError("Error reading from client: %s", w32_error(WSAGetLastError())); } if(i <= 0) { close_client(client); }else if(sdtime && config.warning && now >= sdtime-config.warning) { /* Only allow commands during the warning period */ if(cmd == 'A' && config.abort) { sdtime = calc_sdtime(sdtime); update_all(); break; } if(cmd == 'D' && config.delay) { sdtime += config.delay; update_all(); break; } } } if(now >= sdtime && sdtime) { if(!InitiateSystemShutdown(NULL, NULL, 0, TRUE, FALSE)) { LogError("Shutdown failed: %s", w32_error(GetLastError())); } } } } void svc_cleanup(void) { if(listener >= 0) { closesocket(listener); listener = -1; } if(regkey) { RegDeleteValue(regkey, "ipc-port"); RegCloseKey(regkey); regkey = NULL; } } static std::string reg_get_string(char const *name) { DWORD errnum, size = 0; errnum = RegQueryValueEx(regkey, name, NULL, NULL, NULL, &size); if(errnum != ERROR_SUCCESS) { die("Failed to query '%s' value: %s", name, w32_error(errnum)); } char *buf = new char[size+1]; buf[size] = '\0'; errnum = RegQueryValueEx(regkey, name, NULL, NULL, (BYTE*)buf, &size); if(errnum != ERROR_SUCCESS) { die("Failed to read '%s' value: %s", name, w32_error(errnum)); } std::string value = buf; delete buf; return value; } static DWORD reg_get_dword(char const *name) { DWORD errnum, value, size = sizeof(DWORD); errnum = RegQueryValueEx(regkey, name, NULL, NULL, (BYTE*)&value, &size); if(errnum != ERROR_SUCCESS) { die("Failed to read '%s' value: %s", name, w32_error(errnum)); } return value; } static void reg_set_string(char const *name, char const *value) { DWORD errnum = RegSetValueEx(regkey, name, 0, REG_SZ, (BYTE*)value, strlen(value)+1); if(errnum != ERROR_SUCCESS) { LogError("Failed to write '%s' value: %s", name, w32_error(errnum)); } } static void reg_set_dword(char const *name, DWORD value) { DWORD errnum = RegSetValueEx(regkey, name, 0, REG_DWORD, (BYTE*)&value, sizeof(DWORD)); if(errnum != ERROR_SUCCESS) { LogError("Failed to write '%s' value: %s", name, w32_error(errnum)); } } static void close_client(std::map::iterator client) { FD_CLR(client->first, &read_fds); closesocket(client->first); if(clients.size() == FD_SETSIZE) { FD_SET(listener, &read_fds); } clients.erase(client); } /* Parse a shutdown time and add it to an sdtimes vector */ static void add_sdtime(std::vector *sdtimes, char const *str) { struct sdtime sdtime; sdtime.days = atoi(str); str += strspn(str, "1234567890"); str += strspn(str, ";:"); sdtime.hour = atoi(str); str += strspn(str, "1234567890"); str += strspn(str, ";:"); sdtime.min = atoi(str); str += strspn(str, "1234567890"); str += strspn(str, ";:"); sdtime.sec = atoi(str); sdtimes->push_back(sdtime); } /* Calculate next shutdown time * Returns zero if no shutdown is scheduled */ static time_t calc_sdtime(time_t now) { time_t rval = 0, tval; struct tm *local = localtime(&now); int dbit = (1 << local->tm_wday); std::vector::iterator node = config.sdtimes.begin(); std::vector::iterator end = config.sdtimes.end(); while(node != end) { if(node->days & dbit) { local->tm_hour = node->hour; local->tm_min = node->min; local->tm_sec = node->sec; tval = mktime(local); if(tval > now && (tval < rval || !rval)) { rval = tval; } } node++; if(node == end && !rval) { int mdays = month_days(local->tm_year, local->tm_mon); if(local->tm_mday == mdays) { if(local->tm_mon == 11) { local->tm_year++; local->tm_mon = 0; }else{ local->tm_mon++; } local->tm_mday = 1; }else{ local->tm_mday++; } if(local->tm_wday == 6) { local->tm_wday = 0; }else{ local->tm_wday++; } node = config.sdtimes.begin(); } } return rval; } /* Send update packet to a single client */ static void update_client(int sockfd) { struct packet packet; packet.sdtime = sdtime; packet.warning = config.warning; packet.abort = config.abort; packet.delay = config.delay; unsigned int sent = 0; while(sent < sizeof(packet)) { int i = send(sockfd, (char*)&packet+sent, sizeof(packet)-sent, 0); if(i == -1) { LogError("Send to client failed: %s", w32_error(WSAGetLastError())); break; } sent += i; } } /* Send an update packet to all clients */ static void update_all(void) { struct packet packet; packet.sdtime = sdtime; packet.warning = config.warning; packet.abort = config.abort; packet.delay = config.delay; std::map::iterator client = clients.begin(); std::map::iterator end = clients.end(); while(client != end) { unsigned int sent = 0; while(sent < sizeof(packet)) { int i = send(client->first, (char*)&packet+sent, sizeof(packet)-sent, 0); if(i == -1) { LogError("Send to client failed: %s", w32_error(WSAGetLastError())); break; } sent += i; } client++; } } /* Load configuration from server and save it in the registry */ static void load_sconfig(void) { std::string ctext = dl_config(); if(ctext.empty()) { return; } char const *name = ctext.c_str(), *value; while(*name) { value = name+strcspn(name, "="); value += strspn(value, "="); if(strncmp(name, "warning=", 8) == 0) { config.warning = atoi(value); } if(strncmp(name, "abort=", 6) == 0) { config.abort = atoi(value); } if(strncmp(name, "delay=", 6) == 0) { config.delay = atoi(value); } if(strncmp(name, "sdtime=", 7) == 0) { add_sdtime(&(config.sdtimes), value); } name += strcspn(name, "\r\n"); name += strspn(name, "\r\n"); } std::vector::iterator sdtime = config.sdtimes.begin(); std::vector::iterator end = config.sdtimes.end(); std::string sdtimes; while(sdtime != end) { char buf[64]; sprintf(buf, "%d;%d:%d:%d,", sdtime->days, sdtime->hour, sdtime->min ,sdtime->sec); sdtimes.append(buf); sdtime++; } reg_set_dword("warning", config.warning); reg_set_dword("abort", config.abort); reg_set_dword("delay", config.delay); reg_set_string("sdtimes", sdtimes.c_str()); } #define DLC_ABORT() \ if(sockfd >= 0) { \ closesocket(sockfd); \ } \ if(addr) { \ freeaddrinfo(addr); \ } \ return ""; /* Connect to the server and download the configuration * Returns an empty string upon error */ static std::string dl_config(void) { int sockfd = -1, size; struct addrinfo hints, *addr = NULL; char port[32], buf[256]; std::string ctext; sockfd = socket(AF_INET, SOCK_STREAM, 0); if(sockfd == -1) { LogError("Failed to create socket: %s", w32_error(WSAGetLastError())); DLC_ABORT(); } memset((void*)&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; hints.ai_protocol = IPPROTO_TCP; sprintf(port, "%d", config.server_port); int i = getaddrinfo(config.server_name.c_str(), port, &hints, &addr); if(i) { LogError("Server address lookup failed: %s", w32_error(i)); DLC_ABORT(); } struct linger lopt; lopt.l_onoff = 0; setsockopt(sockfd, SOL_SOCKET, SO_LINGER, (const char*)&lopt, sizeof(lopt)); if(connect(sockfd, addr->ai_addr, addr->ai_addrlen) == -1) { LogError("Error connecting to server: %s", w32_error(WSAGetLastError())); DLC_ABORT(); } gethostname(buf, 256); size = strlen(buf)+1; for(int sent = 0; sent < size; sent += i) { i = send(sockfd, buf+sent, size-sent, 0); if(i == -1) { LogError("Send to server failed: %s", w32_error(WSAGetLastError())); DLC_ABORT(); } } for(i = 1; i > 0;) { i = recv(sockfd, buf, 1024, 0); if(i == -1) { LogError("Recieve from server failed: %s", w32_error(WSAGetLastError())); DLC_ABORT(); } ctext.append(buf, i); } closesocket(sockfd); freeaddrinfo(addr); return ctext; }