123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
-
- from SQLLexer import *
- from SQLParser import *
-
-
class SQLInjection(object):
    """Heuristic SQL-injection check that compares a trusted query to user input.

    Two comparisons are offered:

    * ``validateLex`` -- the per-token-type counts produced by ``SQLLexer``
      must match between the sample query and the user query.
    * ``validateParser`` -- the ASTs produced by ``SQLParser`` must compare
      equal.

    The intermediate results of the most recent call of each check are kept
    on the instance and exposed through ``getLastTokCounters`` /
    ``getLastAsts``.
    """

    def __init__(self):
        # The lexer is built once here and reused by every validateLex() call.
        self.lexer = SQLLexer()
        self.lexer.build()

        self.parser = SQLParser()

        # Token-type counters from the last validateLex() call.
        self.u_tok_counter = None
        self.s_tok_counter = None

        # ASTs from the last validateParser() call.
        self.u_ast = None
        self.s_ast = None

    def _count_tokens(self, sql):
        """Return a fresh token-type counter filled from the tokens of *sql*."""
        # NOTE(review): presumably getTokensHash() returns a new mapping with a
        # zero entry for every token type (the `+= 1` below needs the key to
        # exist) -- verify in SQLLexer.
        counter = self.lexer.getTokensHash()
        for tok in self.lexer.tokenize(sql):
            counter[tok.type] += 1
        return counter

    def validateLex(self, sample_sql, user_sql):
        """Return True if both queries contain the same count of each token type.

        The two counters are stored on the instance (see getLastTokCounters).
        """
        self.s_tok_counter = self._count_tokens(sample_sql)
        self.u_tok_counter = self._count_tokens(user_sql)
        return self.s_tok_counter == self.u_tok_counter

    def getLastTokCounters(self):
        """Return ``(sample_counter, user_counter)`` from the last validateLex()."""
        return self.s_tok_counter, self.u_tok_counter

    def validateParser(self, sample_sql, user_sql):
        """Return True if both queries parse to equal ASTs.

        The two ASTs are stored on the instance (see getLastAsts).
        """
        # NOTE(review): if SQLParser.parse returns None on a syntax error (the
        # PLY default), two unparseable queries compare equal here -- confirm
        # that is acceptable for the caller.
        self.s_ast = self.parser.parse(sample_sql)
        self.u_ast = self.parser.parse(user_sql)
        return self.s_ast == self.u_ast

    def getLastAsts(self):
        """Return ``(sample_ast, user_ast)`` from the last validateParser()."""
        return self.s_ast, self.u_ast

    def print_ast(self, ast):
        """Print *ast* breadth-first, one tree level per output line.

        A non-empty tuple node is ``(label, child, child, ...)``: its label is
        printed on the current level and its children are queued for the next
        level. Any other value is printed via ``str`` and has no children.
        A top-level list is treated as a sequence of root nodes. ``None``
        (e.g. a failed parse) or an empty list prints nothing.
        """
        import sys

        if ast is None:
            return
        # Start from the root node itself so its label sits alone on the first
        # line; a list is already a sequence of roots.
        level = list(ast) if isinstance(ast, list) else [ast]
        while level:
            labels = []
            next_level = []
            for node in level:
                if isinstance(node, tuple) and node:
                    labels.append(str(node[0]))
                    next_level.extend(node[1:])
                else:
                    labels.append(str(node))
            # sys.stdout.write works unchanged on Python 2 and 3.
            sys.stdout.write(" ".join(labels) + "\n")
            level = next_level
-
if __name__ == '__main__':
    # Smoke test: compare a benign query against itself and against a variant
    # carrying an extra AND-clause, once per validation strategy.

    sqlI = SQLInjection()

    benign_sql = """select cat from dog where casa=1 ;"""
    extended_sql = """select cat from dog where casa=1 and cat="miau" ;"""

    # Test 1: lexer check, identical queries.
    print(sqlI.validateLex(benign_sql, benign_sql))

    # Test 2: lexer check, user query has an extra AND-clause.
    print(sqlI.validateLex(benign_sql, extended_sql))

    # Test 3: parser check, identical queries.
    print(sqlI.validateParser(benign_sql, benign_sql))

    # Test 4: parser check, user query has an extra AND-clause.
    print(sqlI.validateParser(benign_sql, extended_sql))
|