shithub: pokecrystal

Download patch

ref: 33d8c7a11711258843ab5abd6944453e44897b9c
parent: 3bd84c1dac0dc7085287bd6a7c822dfa2663cf71
author: Bryan Bishop <[email protected]>
date: Sat Mar 24 17:34:19 EDT 2012

wonderful world of testing

--- a/extras/crystal.py
+++ b/extras/crystal.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 #utilities to help disassemble pokémon crystal
-import sys, os, inspect, md5
+import sys, os, inspect, md5, json
 from copy import copy
 
 #for IntervalMap
@@ -10,6 +10,9 @@
 #for testing all this crap
 import unittest2 as unittest
 
+if not hasattr(json, "dumps"): #fall back to the older write/read names
+    json.dumps, json.loads = json.write, json.read
+
 #table of pointers to map groups
 #each map group contains some number of map headers
 map_group_pointer_table = 0x94000
@@ -4435,13 +4438,22 @@
 
 def process_incbins():
     "parse incbin lines into memory"
-    global incbins
-    incbins = {} #reset
+    global asm, incbin_lines, processed_incbins
+    #load asm if it isn't ready yet
+    if asm == [] or asm == None:
+        load_asm()
+    #get a list of incbins if that hasn't happened yet
+    if incbin_lines == [] or incbin_lines == None:
+        isolate_incbins()
+    #reset the global that this function creates
+    processed_incbins = {}
+    #for each incbin..
     for incbin in incbin_lines:
+        #reset this entry
         processed_incbin = {}
-
+        #get the line number from the global asm line list
         line_number = asm.index(incbin)
-
+        #forget about all the leading characters
         partial_start = incbin[21:]
         start = partial_start.split(",")[0].replace("$", "0x")
         start = eval(start)
@@ -4456,20 +4468,19 @@
         end = start + interval
         end_hex = hex(end).replace("0x", "$")
 
-        processed_incbin = {
-                            "line_number": line_number,
+        processed_incbin = {"line_number": line_number,
                             "line": incbin,
                             "start": start,
                             "interval": interval,
-                            "end": end,
-                           }
-
+                            "end": end, }
         #don't add this incbin if the interval is 0
         if interval != 0:
             processed_incbins[line_number] = processed_incbin
+    return processed_incbins
 
 def reset_incbins():
     "reset asm before inserting another diff"
+    global asm, incbin_lines, processed_incbins
     asm = None
     incbin_lines = []
     processed_incbins = {}
@@ -4580,7 +4591,7 @@
     #confirm it's working
     if do_compile:
         try:
-            subprocess.check_call("cd ../; make clean; LC_CTYPE=C make", shell=True)
+            subprocess.check_call("cd ../; make clean; make", shell=True)
             return True
         except Exception, exc:
             if try_fixing:
@@ -4592,13 +4603,6 @@
     where f(item) == True."""
     return next((i for i in xrange(len(seq)) if f(seq[i])), None)
 
-def is_probably_pointer(input):
-    try:
-        blah = int(input, 16)
-        return True
-    except:
-        return False
-
 def analyze_intervals():
     """find the largest baserom.gbc intervals"""
     global asm, processed_incbins
@@ -4614,10 +4618,11 @@
         results.append(processed_incbins[key])
     return results
 
-def write_all_labels(all_labels):
-    fh = open("labels.json", "w")
+def write_all_labels(all_labels, filename="labels.json"):
+    """dump all_labels as json into filename, closing it even on error"""
+    with open(filename, "w") as fh:
+        fh.write(json.dumps(all_labels))
+    return True
 
 def remove_quoted_text(line):
     """get rid of content inside quotes
@@ -4632,9 +4637,12 @@
         line = line[0:first] + line[second+1:]
     return line
 
-def line_has_comment_address(line, returnable={}):
+def line_has_comment_address(line, returnable={}, bank=None):
     """checks that a given line has a comment
-    with a valid address"""
+    with a valid address, and returns the address in the object.
+    Note: bank is required if you have a 4-letter-or-less address,
+    because otherwise there is no way to figure out which bank
+    is curretly being scanned."""
     #first set the bank/offset to nada
     returnable["bank"] = None
     returnable["offset"] = None
@@ -4658,7 +4666,7 @@
     if line[-2:] == "; ":
         return False
     #and multiple whitespace doesn't count either
-    line = line.rstrip(" ")
+    line = line.rstrip(" ").lstrip(" ")
     if line[-1] == ";":
         return False
     #there must be more content after the semicolon
@@ -4675,7 +4683,7 @@
         token = comment.split(" ")[0]
     if token in ["0x", "$", "x", ":"]:
         return False
-    bank, offset = None, None
+    offset = None
     #process a token with a A:B format
     if ":" in token: #3:3F0A, $3:$3F0A, 0x3:0x3F0A, 3:3F0A
         #split up the token
@@ -4717,10 +4725,8 @@
     elif "$" in token and not "x" in token:
         token = token.replace("$", "0x")
         offset = int(token, 16)
-        bank = calculate_bank(offset)
     elif "0x" in token and not "$" in token:
         offset = int(token, 16)
-        bank = calculate_bank(offset)
     else: #might just be "1" at this point
         token = token.lower()
         #check if there are bad characters
@@ -4728,9 +4734,10 @@
             if c not in valid:
                 return False
         offset = int(token, 16)
-        bank = calculate_bank(offset)
     if offset == None and bank == None:
         return False
+    if bank == None:
+        bank = calculate_bank(offset)
     returnable["bank"] = bank
     returnable["offset"] = offset
     returnable["address"] = calculate_pointer(offset, bank=bank)
@@ -4773,7 +4780,7 @@
     return without_addresses
 
 label_errors = ""
-def get_labels_between(start_line_id, end_line_id, bank_id):
+def get_labels_between(start_line_id, end_line_id, bank):
     labels = []
     #label = {
     #   "line_number": 15,
@@ -4782,6 +4789,8 @@
     #   "offset": 0x5315,
     #   "address": 0x75315,
     #}
+    if asm == None:
+        load_asm()
     sublines = asm[start_line_id : end_line_id + 1]
     for (current_line_offset, line) in enumerate(sublines):
         #skip lines without labels
@@ -4794,7 +4803,7 @@
         #setup a place to store return values from line_has_comment_address
         returnable = {}
         #get the address from the comment
-        has_comment = line_has_comment_address(line, returnable=returnable)
+        has_comment = line_has_comment_address(line, returnable=returnable, bank=bank)
         #skip this line if it has no address in the comment
         if not has_comment: continue
         #parse data from line_has_comment_address
@@ -4813,7 +4822,7 @@
         labels.append(label)
     return labels
 
-def scan_for_predefined_labels():
+def scan_for_predefined_labels(debug=False):
     """looks through the asm file for labels at specific addresses,
     this relies on the label having its address after. ex:
 
@@ -4825,8 +4834,9 @@
     addresses, but faster to write this script. rgbasm would be able
     to grab all label addresses better than this script..
     """
-    bank_intervals = {}
+    global all_labels
     all_labels = []
+    bank_intervals = {}
 
     #figure out line numbers for each bank
     for bank_id in range(0x7F+1):
@@ -4836,29 +4846,34 @@
             abbreviation = "0"
             abbreviation_next = "1"
 
+        #calculate the start/stop line numbers for this bank
         start_line_id = index(asm, lambda line: "\"bank" + abbreviation + "\"" in line)
-
-        if bank_id != 0x2c:
+        if bank_id != 0x7F:
             end_line_id = index(asm, lambda line: "\"bank" + abbreviation_next + "\"" in line)
+            end_line_id += 1
         else:
             end_line_id = len(asm) - 1
 
-        print "bank" + abbreviation + " starts at " + str(start_line_id) + " to " + str(end_line_id)
-
-        bank_intervals[bank_id] = {
-                                    "start": start_line_id,
-                                    "end": end_line_id,
-                                  }
+        if debug:
+            output = "bank" + abbreviation + " starts at "
+            output += str(start_line_id)
+            output += " to "
+            output += str(end_line_id)
+            print output
+        
+        #store the start/stop line number for this bank
+        bank_intervals[bank_id] = {"start": start_line_id,
+                                   "end": end_line_id,}
+    #for each bank..
     for bank_id in bank_intervals.keys():
+        #get the start/stop line number
         bank_data = bank_intervals[bank_id]
-
         start_line_id = bank_data["start"]
         end_line_id   = bank_data["end"]
-
+        #get all labels between these two lines
         labels = get_labels_between(start_line_id, end_line_id, bank_id)
         #bank_intervals[bank_id]["labels"] = labels
         all_labels.extend(labels)
-
     write_all_labels(all_labels)
     return all_labels
 
@@ -5112,6 +5127,9 @@
         self.assertTrue(x(";3:FFAA"))
         self.assertFalse(x('hello world "how are you today;0x1"'))
         self.assertTrue(x('hello world "how are you today:0x1";1'))
+        returnable = {}
+        self.assertTrue(x("hello_world: ; 0x4050", returnable=returnable, bank=5))
+        self.assertTrue(returnable["address"] == 0x14050)
     def test_line_has_label(self):
         x = line_has_label
         self.assertTrue(x("hi:"))
@@ -5135,6 +5153,109 @@
         labels = find_labels_without_addresses()
         self.failUnless(len(labels) == 0)
         asm = None
+    def test_get_labels_between(self):
+        """only labelled lines with a comment address are returned"""
+        global asm
+        asm = ["HelloWorld: ;1",
+               "hi:",
+               "no label on this line",
+              ]
+        try:
+            #(start_line_id, end_line_id, bank)
+            labels = get_labels_between(0, 2, 0x12)
+            self.assertEqual(len(labels), 1)
+            self.assertEqual(labels[0]["label"], "HelloWorld")
+        finally:
+            #reset instead of del: other functions compare asm to None
+            asm = None
+    def test_scan_for_predefined_labels(self):
+        """labels known to exist in the asm are picked up"""
+        #label keys: line_number, bank, label, offset, address
+        #note: scan_for_predefined_labels also writes labels.json
+        load_asm()
+        all_labels = scan_for_predefined_labels()
+        label_names = [x["label"] for x in all_labels]
+        self.assertIn("GetFarByte", label_names)
+        self.assertIn("AddNTimes", label_names)
+        self.assertIn("CheckShininess", label_names)
+    def test_write_all_labels(self):
+        """dumping json into a file"""
+        filename = "test_labels.json"
+        #remove the current file
+        if os.path.exists(filename):
+            os.remove(filename)
+        #make up some labels
+        labels = [{"line_number": 5, "bank": 0, "label": "SomeLabel", "address": 0x10},
+                  {"line_number": 15, "bank": 2, "label": "SomeOtherLabel", "address": 0x9F0A}]
+        try:
+            #dump to file
+            write_all_labels(labels, filename=filename)
+            #open the file and read the contents
+            with open(filename, "r") as file_handler:
+                contents = file_handler.read()
+            #parse into json; json modules without loads only offer read
+            if hasattr(json, "loads"):
+                obj = json.loads(contents)
+            else:
+                obj = json.read(contents)
+            #begin testing
+            self.assertEqual(len(obj), len(labels))
+            self.assertEqual(len(obj), 2)
+            self.assertEqual(obj, labels)
+        finally:
+            #don't leave the test file behind
+            if os.path.exists(filename):
+                os.remove(filename)
+    def test_isolate_incbins(self):
+        """only the INCBIN lines are pulled out of asm"""
+        global asm, incbin_lines
+        asm = ["123", "456", "789", "abc", "def", "ghi",
+               'INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
+               "jkl",
+               'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
+        try:
+            lines = isolate_incbins()
+            self.assertIn(asm[6], lines)
+            self.assertIn(asm[8], lines)
+            for line in lines:
+                self.assertIn("baserom", line)
+        finally:
+            #don't leak the fake asm into other tests
+            asm, incbin_lines = None, []
+    def test_process_incbins(self):
+        """incbins are keyed by their line number in asm"""
+        global incbin_lines, processed_incbins, asm
+        incbin_lines = ['INCBIN "baserom.gbc",$12DA,$12F8 - $12DA',
+                        'INCBIN "baserom.gbc",$137A,$13D0 - $137A']
+        asm = copy(incbin_lines)
+        asm.insert(1, "some other random line")
+        try:
+            processed_incbins = process_incbins()
+            self.assertEqual(len(processed_incbins), len(incbin_lines))
+            self.assertEqual(processed_incbins[0]["line"], incbin_lines[0])
+            self.assertEqual(processed_incbins[2]["line"], incbin_lines[1])
+        finally:
+            #don't leak the fake asm/incbins into other tests
+            asm, incbin_lines, processed_incbins = None, [], {}
+    def test_reset_incbins(self):
+        """reset_incbins clears asm, incbin_lines and processed_incbins"""
+        global asm, incbin_lines, processed_incbins
+        #temporarily override the functions
+        global load_asm, isolate_incbins, process_incbins
+        temp1, temp2, temp3 = load_asm, isolate_incbins, process_incbins
+        def load_asm(): pass
+        def isolate_incbins(): pass
+        def process_incbins(): pass
+        try:
+            #call reset
+            reset_incbins()
+            #check the results
+            self.assertTrue(asm == [] or asm == None)
+            self.assertTrue(incbin_lines == [])
+            self.assertTrue(processed_incbins == {})
+        finally:
+            #restore the original functions even if an assertion failed
+            load_asm, isolate_incbins, process_incbins = temp1, temp2, temp3
 class TestMapParsing(unittest.TestCase):
     #def test_parse_warp_bytes(self):
     #    pass #or raise NotImplementedError, bryan_message
--- a/main.asm
+++ b/main.asm
@@ -24,7 +24,7 @@
 
 INCBIN "baserom.gbc",$305d,$30fe-$305d
 
-AddNTimes ; 0x30fe
+AddNTimes: ; 0x30fe
 	and a
 	ret z
 .loop