From 47e577ebaf07e4785a7a7cabfa402c36087fa1eb Mon Sep 17 00:00:00 2001 From: nico Date: Wed, 17 Jul 2019 17:56:15 +0200 Subject: Code cleanup (#2) * config class revamp + add get_at method + add set_at method + add unset_at method new config class to interactively get/set/unset config parameters from within other methods * move report function to new file + move all report functions to new report class the new report class works exactly like the internal one, but it is easier to maintain * + implement the new config functions + implement the changed report functions + update docstrings + update code comments * code cleanup * further code cleanup + add variable type to all methods * many indentation fixes --- config.py | 71 +++++++++++++++++----- main.py | 203 ++++++++++++++++---------------------------------------------- report.py | 133 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 240 insertions(+), 167 deletions(-) create mode 100644 report.py diff --git a/config.py b/config.py index fac5f9b..10e938f 100644 --- a/config.py +++ b/config.py @@ -2,22 +2,63 @@ import json import os -# filepath of the config.json in the project directory -path = os.path.dirname(__file__) -filepath = ("/".join([path, "config.json"])) -# try to read config.json if nonexistent create config.json an populate it -try: - with open(filepath, "r", encoding="utf-8") as f: - config = json.load(f) +class Config(object): + def __init__(self): + self.config = dict() + self.valid_config = bool -except FileNotFoundError: - with open(filepath, "w", encoding="utf-8") as f: - config = { - "name": "", - } - f.write(json.dumps(config)) + # filepath of the config.json in the project directory + self.path = os.path.dirname(__file__) + self.filepath = ('/'.join([self.path, 'config.json'])) + # load config + self.load() -class Config(object): - name = config["name"] + def load(self): + try: + # try to read config.json + with open(self.filepath, "r", encoding="utf-8") as f: + self.config = json.load(f) + + except FileNotFoundError: + # if file is absent create file + open(self.filepath, "w").close() + + except json.decoder.JSONDecodeError: + # config file is present but empty + pass + + def get_at(self, attrib: str): + """ + retrieve attribute from config file + :param attrib: keyword corresponding to keyword in config dictionary + :return: value of specified keyword or False if keyword is not present in dictionary + """ + if attrib in self.config: + # return corresponding attrib from config + return self.config[attrib] + else: + # if attrib is not present in config return False + self.config[attrib] = False + + def set_at(self, attrib: str, param): + """ + set attribute to parameter inside config file + :param attrib: keyword which should be updated/created in config dictionary + :param param: parameter the keyword should be updated to + """ + self.config[attrib] = param + + # save new attrib to file + with open(self.filepath, "w", encoding="utf-8") as f: + f.write(json.dumps(self.config, indent=4)) + + def unset_at(self, attrib: str): + """ + unset attribute inside config file + :param attrib: attribute which should be unset inside config file + """ + if attrib in self.config: + # only if attrib is actually present unset it + self.config.pop(attrib) diff --git a/main.py b/main.py index 7090eb5..d885b7a 100755 --- a/main.py +++ b/main.py @@ -6,12 +6,13 @@ import gzip import os import re import sqlite3 +import sys -import dns.resolver as dns import tabulate from defusedxml import ElementTree from config import Config +from report import ReportDomain class AbuseReport: @@ -22,15 +23,15 @@ class AbuseReport: self.domain = arguments.domain self.report = arguments.report self.path = os.path.dirname(__file__) + self.config = Config() self.conn = sqlite3.connect("/".join([self.path, "spam.db"])) self.jid_pattern = re.compile("^(?:([^\"&'/:<>@]{1,1023})@)?([^/@]{1,1023})(?:/(.{1,1023}))?$") self.message_pattern = re.compile(r'', re.DOTALL) def main(self): - """ - method deciding over which action to take - """ + """main method guiding the actions to take""" + if self.infile is None: # infile unset -> report top10 self.egest() @@ -43,46 +44,48 @@ class AbuseReport: self.conn.close() def egest(self): - """ - egest method - if specific domain is supplied return only those results - in any other case return top 10 table - """ + """egest method returning the database results""" + # init result list result = list() - # if domain is specified return info for that domain + # if a domain is specified return only that info if self.domain is not None: - result = list() # iterate over all domains supplied for domain in self.domain: - query = self.conn.execute('''SELECT COUNT(*) AS messages,COUNT(DISTINCT user) AS bots,domain, - MIN(ts) AS first,MAX(ts) AS last FROM spam WHERE domain = :domain;''', - {"domain": domain}).fetchall() + sql_query = self.conn.execute('''SELECT COUNT(*) AS messages,COUNT(DISTINCT user) AS bots,domain, MIN(ts) + AS first,MAX(ts) AS last FROM spam WHERE domain = :domain;''',{"domain": domain}).fetchall() - # if specified domain is not listed yet, the resulting table would miss the domain name - # this ugle tuple 2 list swap prevents this behaviour - temp = list(query[0]) + # if specified domain is not listed yet, the resulting table will not show the domain name + # this ugly tuple 2 list swap prevents this + temp = list(sql_query[0]) if temp[2] is None: temp[2] = domain - query[0] = tuple(temp) + sql_query[0] = tuple(temp) - # extend result table - result.extend(query) + # extend result tables + result.extend(sql_query) # generate report if enabled if self.report: - self.gen_report(domain, query) + self.gen_report(domain, sql_query) + else: - # in any other case return top 10 - result = self.conn.execute('''SELECT COUNT(*) AS messages,COUNT(DISTINCT user) AS bots,domain AS domain - FROM spam GROUP BY domain ORDER BY 1 DESC LIMIT 10;''') + # in any other case return top 10 view + if self.config.get_at("top10_view"): + result = self.conn.execute('''SELECT * FROM "top10"''').fetchall() + else: + result = self.conn.execute('''SELECT COUNT(*) AS messages,COUNT(DISTINCT user) AS bots,domain AS domain + FROM spam GROUP BY domain ORDER BY 1 DESC LIMIT 10''').fetchall() + + # tabelize data + spam_table = tabulate.tabulate(result, headers=["messages", "bots", "domain", "first seen", "last seen"], + tablefmt="github") - # format data as table - table = tabulate.tabulate(result, headers=["messages", "bots", "domain","first seen", "last seen"], - tablefmt="orgtbl") - print(table) + # output to stdout + output = "\n\n".join([spam_table]) + print(output, file=sys.stdout) def ingest(self): """ @@ -101,7 +104,7 @@ class AbuseReport: except FileNotFoundError as err: content = "" - print(err) + print(err, file=sys.stderr) # if magic number is present decompress and decode file if content.startswith(magic_number): @@ -112,38 +115,27 @@ class AbuseReport: # automated run None catch if content is not None: - self.parse(content) - - def parse(self, infile): - """ - method to parse xml messages - :type infile: str - :param infile: string containing xml stanzas - """ - log = re.findall(self.message_pattern, infile) + log = re.findall(self.message_pattern, content) - if log is not None: - self.db_import(log) + if log is not None: + self.db_import(log) - def db_import(self, message_log): + def db_import(self, message_log: list): """ import xml stanzas into database - :type infile: str - :param message_log: xml messages + :param message_log: list of xml messages """ - self.conn.execute('''CREATE TABLE IF NOT EXISTS "spam" ("user" TEXT, "domain" TEXT, "ts" TEXT, "message" TEXT, - PRIMARY KEY("domain","ts"));''') - for message in message_log: message_parsed = ElementTree.fromstring(message) - # parse from tag + # parse 'from' tag spam_from = message_parsed.get('from') match = self.jid_pattern.match(spam_from) (node, domain, resource) = match.groups() # stamp all_delay_tags = message_parsed.findall('.//{urn:xmpp:delay}delay') + spam_time = None for tag in all_delay_tags: if "@" in tag.get("from"): continue @@ -157,27 +149,29 @@ class AbuseReport: # format sql try: - self.conn.execute('INSERT INTO spam VALUES(:user, :domain, :spam_time, :spam_body);', - {"user": node, "domain": domain, "spam_time": spam_time, "spam_body": spam_body}) + self.conn.execute('''INSERT INTO spam VALUES(:user, :domain, :spam_time, :spam_body);''', + {"user": node, "domain": domain, "spam_time": spam_time, "spam_body": spam_body}) except sqlite3.IntegrityError: pass finally: self.conn.commit() - def gen_report(self, domain, query): + def gen_report(self, domain: str, query: list): """ method generating the report files - :type domain: str :param domain: string containing a domain name - :param query: sqlite cursor object containing the query results for the specified domain + :param query: list of tuples containing the query results for the specified domain/s """ + # init report class + report = ReportDomain(self.config, self.conn) + try: # open abuse report template file with open("/".join([self.path, "template/abuse-template.txt"]), "r", encoding="utf-8") as template: report_template = template.read() except FileNotFoundError as err: - print(err) + print(err, file=sys.stderr) exit(1) # current date @@ -190,118 +184,23 @@ class AbuseReport: # write report files with open("/".join([self.path, "report", report_filename]), "w", encoding="utf-8") as report_out: - content = self.report_template(report_template, domain, query) + content = report.template(report_template, domain, query) report_out.write(content) with open("/".join([self.path, "report", jids_filename]), "w", encoding="utf-8") as report_out: - content = self.report_jids(domain) + content = report.jids(domain) report_out.write(content) with open("/".join([self.path, "report", logs_filename]), "w", encoding="utf-8") as report_out: - content = self.report_logs(domain) + content = report.logs(domain) report_out.write(content) - def report_template(self, template, domain, query): - """ - method to collect and format the template file to the final abuse report - :type template: str - :type domain: str - :param template: string containing the abuse report template - :param domain: string containing a domain name - :param query: sqlite cursor object containing the query results for the specified domain - :return: string containing the fully formatted abuse report - """ - name = Config.name - - # lookup srv and domain info - info = self.srvlookup(domain) - srv = info[0]["host"] - ips = "".join(info[0]["ip"]) - summary = tabulate.tabulate(query, headers=["messages", "bots", "domain","first seen", "last seen"], - tablefmt="orgtbl") - - report_out= template.format(name=name, domain=domain, srv=srv, ips=ips, summary=summary) - - return report_out - - def report_jids(self, domain): - """ - method to collect all involved jids from the database - :type domain: str - :param domain: string containing a domain name - :return: formatted string containing the result - """ - - jids = self.conn.execute('''SELECT user || '@' || domain as jid FROM spam WHERE domain=:domain GROUP BY user - ORDER BY 1;''', {"domain": domain}).fetchall() - - return tabulate.tabulate(jids, tablefmt="plain") - - def report_logs(self, domain): - """ - method to collect all messages grouped by frequency - :type domain: str - :param domain: string containing a domain name - :return: formatted string containing the result - """ - logs = self.conn.execute('''SELECT char(10)||MIN(ts)||' - '||MAX(ts)||char(10)||COUNT(*)||' messages:'||char(10) - ||'========================================================================'||char(10)||message||char(10)|| - '========================================================================' FROM spam WHERE domain=:domain - GROUP BY message ORDER BY COUNT(*) DESC LIMIT 10;''', {"domain": domain}).fetchall() - - return tabulate.tabulate(logs, tablefmt="plain") - - def srvlookup(self, domain): - """ - srv lookup method for the domain provided, if no srv record is found the base domain is used - :type domain: str - :param domain: provided domain to query srv records for - :return: sorted list of dictionaries containing host and ip info - """ - # srv - query = '_xmpp-client._tcp.{}'.format(domain) - - try: - srv_records = dns.query(query, 'SRV') - - except (dns.NXDOMAIN, dns.NoAnswer): - # catch NXDOMAIN and NoAnswer tracebacks - srv_records = None - - # extract record - results = list() - - if srv_records is not None: - # extract all available records - for record in srv_records: - info = dict() - - # gather necessary info from srv records - info["host"] = str(record.target).rstrip('.') - info["weight"] = record.weight - info["priority"] = record.priority - info["ip"] = [ip.address for ip in dns.query(info["host"], "A")] - results.append(info) - - # return list sorted by priority and weight - return sorted(results, key=lambda i: (i['priority'], i["weight"])) - - # prevent empty info when srv records are not present - info = dict() - - # gather necessary info from srv records - info["host"] = domain - info["ip"] = [ip.address for ip in dns.query(info["host"], "A")] - results.append(info) - - return results - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('-in', '--infile', nargs='+', help='set path to input file', dest='infile') parser.add_argument('-d', '--domain', action='append', help='specify report domain', dest='domain') - parser.add_argument('-r', '--report', action='store_true', help='toggle report output to file', dest='report') + parser.add_argument('-r', '--report', action='store_true', help='toggle report output to file', dest='report') args = parser.parse_args() # run diff --git a/report.py b/report.py new file mode 100644 index 0000000..6357db5 --- /dev/null +++ b/report.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +import dns.resolver as dns +import tabulate + + +class ReportDomain: + def __init__(self, config, conn): + """ + :param config: configuration object + :param conn: sqlite connection object + """ + self.config = config + self.conn = conn + + def template(self, template: str, domain: str, query: list): + """ + method to retrieve and format the template file + :param template: string containing the abuse report template + :param domain: string containing a domain name + :param query: list of tuples containing the query results for the specified domain/s + :return: string containing the fully formatted abuse report + """ + name = self.config.get_at("name") + + # lookup and format srv target and ip + srv, ips = self.srv(domain) + summary = tabulate.tabulate(query, headers=["messages", "bots", "domain", "first seen", "last seen"], + tablefmt="github") + + report_out = template.format(name=name, domain=domain, srv=srv, ips=ips, summary=summary) + + return report_out + + def jids(self, domain: str): + """ + method to collect all involved jids from the database + :param domain: string containing a domain name + :return: formatted result string + """ + + jids = self.conn.execute('''SELECT user || '@' || domain as jid FROM spam WHERE domain=:domain GROUP BY user + ORDER BY 1;''', {"domain": domain}).fetchall() + + return tabulate.tabulate(jids, tablefmt="plain") + + def logs(self, domain: str): + """ + method to collect all messages grouped by frequency + :param domain: string containing a domain name + :return: formatted string containing the result + """ + logs = self.conn.execute('''SELECT CHAR(10) || MIN(ts) || ' - ' || MAX(ts) || char(10) || COUNT(*) || + 'messages:' || char(10) ||'========================================================================' || + char(10) || message || char(10) || '========================================================================' + FROM spam WHERE domain=:domain GROUP BY message ORDER BY COUNT(*) DESC LIMIT 10;''', {"domain": domain}).fetchall() + + return tabulate.tabulate(logs, tablefmt="plain") + + def srv(self, domain: str, only_highest: bool = True): + info = self._srvlookup(domain) + + if only_highest: + target = info[0]["host"] + ips = info[0]["ip"] + + return target, ips + + return info + + @staticmethod + def _getip(domain: str): + """ + method to query the a / aaaa record of a specified domain + :param domain: valid domain target + :return: filtered list of all a/ aaaa records + """ + # init records + a, a4 = None, None + + try: + # query and join both a and aaaa records + a = ", ".join([ip.address for ip in dns.query(domain, "A")]) + a4 = ", ".join([ip.address for ip in dns.query(domain, "AAAA")]) + + except (dns.NXDOMAIN, dns.NoAnswer): + # catch NXDOMAIN and NoAnswer tracebacks not really important + pass + + return list(filter(None.__ne__, [a, a4])) + + def _srvlookup(self, domain: str): + """ + srv lookup method for the domain provided, if no srv record is found the base domain is used + :param domain: provided domain to query srv records for + :return: sorted list of dictionaries containing host and ip info + """ + # init + results = list() + srv_records = None + + try: + srv_records = dns.query('_xmpp-client._tcp.{}'.format(domain), 'SRV') + + except (dns.NXDOMAIN, dns.NoAnswer): + # catch NXDOMAIN and NoAnswer tracebacks + pass + + # extract record + if srv_records is not None: + # extract all available records + for record in srv_records: + info = dict() + + # gather necessary info from srv records + info["host"] = record.target.to_text().rstrip('.') + info["port"] = record.port + info["weight"] = record.weight + info["priority"] = record.priority + info["ip"] = ", ".join(self._getip(record.target.to_text())) + results.append(info) + + # return list sorted by priority and weightre + return sorted(results, key=lambda i: (i['priority'], i["weight"])) + + # prevent empty info when srv records are not present + info = dict() + + # gather necessary info from srv records + info["host"] = domain + info["ip"] = ", ".join(self._getip(domain)) + results.append(info) + + return results -- cgit v1.2.3-18-g5258