"""Compare a reference ("sample") SQL statement with a user-supplied one.

Two structural checks are offered by ``SQLInjection``:

* ``validateLex`` -- both statements must yield the same number of tokens of
  each token type.
* ``validateParser`` -- both statements must parse to equal ASTs.

A mismatch means the user-supplied text changed the shape of the query
relative to the sample, which is what an injected fragment would do.
"""
from __future__ import print_function

from SQLLexer import *
from SQLParser import *


class SQLInjection(object):
    """Structural comparison of a sample query against a user query.

    The token counters / ASTs computed by the most recent ``validate*`` call
    are kept on the instance so callers can inspect why a comparison failed
    (see ``getLastTokCounters`` and ``getLastAsts``).
    """

    def __init__(self):
        self.lexer = SQLLexer()
        self.lexer.build()
        self.parser = SQLParser()
        # Populated by validateLex(); None until it has run.
        self.u_tok_counter = None
        self.s_tok_counter = None
        # Populated by validateParser(); None until it has run.
        self.u_ast = None
        self.s_ast = None

    def _count_tokens(self, sql):
        """Return a mapping of token type -> number of occurrences in *sql*."""
        # NOTE(review): assumes getTokensHash() returns a *fresh* mapping
        # pre-seeded with every token type (``[tok.type] += 1`` needs the key
        # to exist). If it ever returned a shared object, the sample and user
        # counters would alias and validateLex would always succeed -- confirm
        # in SQLLexer.
        counter = self.lexer.getTokensHash()
        for tok in self.lexer.tokenize(sql):
            counter[tok.type] += 1
        return counter

    def validateLex(self, sample_sql, user_sql):
        """Return True if both statements have identical token-type counts.

        Only how many tokens of each type appear is compared, not their
        order. The two counters are stored for ``getLastTokCounters``.
        """
        self.s_tok_counter = self._count_tokens(sample_sql)
        self.u_tok_counter = self._count_tokens(user_sql)
        return self.s_tok_counter == self.u_tok_counter

    def getLastTokCounters(self):
        """Return ``(sample_counter, user_counter)`` from the last validateLex."""
        return self.s_tok_counter, self.u_tok_counter

    def validateParser(self, sample_sql, user_sql):
        """Return True if both statements parse to equal ASTs.

        Both ASTs are stored for ``getLastAsts``.
        """
        self.s_ast = self.parser.parse(sample_sql)
        self.u_ast = self.parser.parse(user_sql)
        return self.s_ast == self.u_ast

    def getLastAsts(self):
        """Return ``(sample_ast, user_ast)`` from the last validateParser."""
        return self.s_ast, self.u_ast

    def print_ast(self, ast):
        """Print *ast* breadth-first, one tree level per line.

        A tuple node is treated as ``(label, child, child, ...)``: its label is
        printed on the current line and its children are queued for the next
        line. Any other node is printed as a leaf.

        The first line consists of the elements of *ast* itself, not *ast* as
        a single node. ``None`` (for example, from a failed parse) and empty
        sequences print nothing.
        """
        if ast is None:
            # Previously crashed with TypeError on len(None).
            return
        level = ast
        while level:
            labels = []
            next_level = []
            for node in level:
                if isinstance(node, tuple):
                    labels.append(node[0])
                    next_level.extend(node[1:])
                else:
                    labels.append(node)
            # Space-separated with no trailing space, matching Python 2's
            # ``print x,`` soft-space output followed by a bare ``print``.
            print(" ".join(str(label) for label in labels))
            level = next_level


if __name__ == '__main__':
    sqlI = SQLInjection()

    # Test 1: identical statements, lexical comparison.
    print(sqlI.validateLex("""select cat from dog where casa=1 ;""",
                           """select cat from dog where casa=1 ;"""))
    # Test 2: user statement carries an extra `and` predicate, lexical comparison.
    print(sqlI.validateLex("""select cat from dog where casa=1 ;""",
                           """select cat from dog where casa=1 and cat="miau" ;"""))
    # Test 3: identical statements, AST comparison.
    print(sqlI.validateParser("""select cat from dog where casa=1 ;""",
                              """select cat from dog where casa=1 ;"""))
    # Test 4: user statement carries an extra `and` predicate, AST comparison.
    print(sqlI.validateParser("""select cat from dog where casa=1 ;""",
                              """select cat from dog where casa=1 and cat="miau" ;"""))