#!/usr/bin/env python3

# FIXME:
#
#   1. It might be better if we taught all the objects to compare themselves?
#      But would that make the ordering less clear?
#
#   2. In general this script is pretty messy, but it’s an OK way to get
#      things rolling, I think.

import ast
import os
import os.path
import pkgutil
import re
import sys

# This script depends on Python 3.9 for (1) str.removesuffix, which is easy to
# work around, and (2) attribute ast.AST.end_lineno, which is not. I’m not
# sure how to address this in a tidy way -- ideally the test would skip, but
# how do we know at the shell level whether we have a good Python? -- but
# simply succeeding if Python is too old is very easy. We have enough Python
# ≥3.9 testing happening that I think it’s unlikely we’ll miss something.
if (sys.version_info < (3, 9)):
   print("Python is too old for this script: %s" % sys.version)
   print("exiting successfully")
   sys.exit(0)

# The file to analyze is the first command line argument; any further
# arguments are ignored. Complain clearly if it is missing rather than dying
# with a bare IndexError.
if (len(sys.argv) < 2):
   print("usage: %s PATH" % os.path.basename(sys.argv[0]), file=sys.stderr)
   sys.exit(2)
path = sys.argv[1]
print("analyzing %s" % path)

# Rank of each statement class that is subject to ordering; a lower rank
# means the statement type belongs earlier. The rank is the position in the
# list below.
CLASS_ORDER = {
   cls: rank
   for (rank, cls) in enumerate([
      ast.Import,
      ast.Assign,
      ast.AnnAssign,
      ast.AugAssign,
      ast.ClassDef,
      ast.FunctionDef,
   ])
}

# Rank of each recognized section name, again by position in the list below.
SECTION_ORDER = {
   title: rank
   for (rank, title) in enumerate([
      "Enums",
      "Constants",
      "Globals",
      "Exceptions",
      "Main",
      "Functions",
      "Supporting classes",
      "Core classes",
      "Classes",
   ])
}

# Names of standard library modules. Python ≥3.10 has this as
# sys.stdlib_module_names without the fooling around. On 3.9, approximate it
# with every importable module found under a “python3.N” directory, leaving
# out site-packages and dist-packages so that installed third-party packages
# are not mistaken for standard library. Built-in modules are added either
# way.
if (hasattr(sys, "stdlib_module_names")):
   STDLIB_MODULES = set(sys.stdlib_module_names)
else:
   STDLIB_MODULES = {     m.name
                      for m in pkgutil.iter_modules()
                      if  (    re.search(r"/python3\.\d+",
                                         m.module_finder.path)
                           and not re.search(r"/(site|dist)-packages",
                                             m.module_finder.path)) }
STDLIB_MODULES |= set(sys.builtin_module_names)

# Names of the modules that are siblings of the analyzed file, i.e. every
# “.py” file in the same directory except charliecloud.py. dirname() is the
# empty string if path has no directory part, which listdir() rejects, so
# fall back to the current directory.
CH_MODULES = {     f.removesuffix(".py")
               for f in os.listdir(os.path.dirname(path) or ".")
               if  (f.endswith(".py") and f != "charliecloud.py") }

# Read the file to analyze. Python source is UTF-8 by default, so decode it
# that way regardless of locale, and close the file when done.
with open(path, encoding="utf-8") as fp:
   text = fp.read()
# lines[n] is source line n as numbered by ast. Split on "\n" only (open()
# has already translated other newline styles): str.splitlines() would also
# split on form feeds, U+2028, etc., which ast does not count as line breaks,
# and that would shift every later line number.
lines = ["# DUMMY LINE TO MAKE IT 1-INDEXED"] + text.split("\n")
tree = ast.parse(text, filename=path)
error_ct = 0            # number of problems reported so far; see FAIL()
classes_seen = set()


class Section(ast.AST):
   """Pseudo-statement representing a section comment, e.g. “## Globals ##”,
      so it can sit in the same list as the real statements from ast.

      Attributes:

        lineno, end_lineno ... line numbers of the comment
        name ................. section title, i.e. the text between the
                               hash marks"""

   def __init__(self, lineno, end_lineno, name):
      self.name = name
      self.lineno = lineno
      self.end_lineno = end_lineno

   def __str__(self):
      return "<Section: %s>" % self.name

   @classmethod
   def parse(cls, lineno, end_lineno):
      """Search global lines numbered lineno through end_lineno (inclusive)
         for a section comment. Return a new Section for the first one
         found, or None if there isn’t one."""
      candidates = lines[lineno:end_lineno+1]
      for (num, line) in enumerate(candidates, start=lineno):
         match = re.search(r"^ *##+ (.+) ##+", line)
         if (match is None):
            continue
         return cls(num, num, match[1])
      return None


def FAIL(lineno, msg):
   "Print error message msg for line number lineno and count the error."
   global error_ct
   report = "😭 %d: %s" % (lineno, msg)
   print(report)
   error_ct = error_ct + 1

def FAILO(before, after):
   """Report an ordering error: before appears first in the file but should
      come after after. Both need name and lineno attributes."""
   where = before.lineno
   what = (  "%s precedes %s (line %d) but should follow"
           % (before.name, after.name, after.lineno))
   FAIL(where, what)

def inherits(child, parent):
   """Return True if child inherits from parent (perhaps indirectly), else
      False. Both arguments are statements; if either is not an
      ast.ClassDef, the answer is False.

      All bases of each class are followed, not just the first. Only bases
      written as a plain name can be followed, and only through classes
      defined at the top level of the analyzed file (global tree); anything
      else is assumed to live in another file."""
   if (not (    isinstance(child, ast.ClassDef)
            and isinstance(parent, ast.ClassDef))):
      return False  # not both classes
   looked_up = set()   # base names already resolved, to survive cycles
   pending = [child]   # classes whose bases still need examining
   while (len(pending) > 0):
      cls = pending.pop()
      for base in cls.bases:
         if (not isinstance(base, ast.Name)):
            continue     # dot in name etc., so can’t be in this file
         if (base.id == parent.name):
            return True  # found a match
         if (base.id in looked_up):
            continue     # been here; e.g. class shadowing its own base
         looked_up.add(base.id)
         # Move up one generation. All we have here is the name, so search
         # all the module statements for a matching class. 🤪 No match means
         # the base must be in another file.
         for stmt in tree.body:
            if (isinstance(stmt, ast.ClassDef) and stmt.name == base.id):
               pending.append(stmt)
               break
   return False  # ran out of ancestors without finding parent


def _target_name(target):
   """Return a printable name for assignment target target (an ast
      expression node). Sequence targets give their elements’ names joined
      with commas; anything that is not a plain name, e.g. an attribute or
      subscript, gives its source text."""
   if (isinstance(target, ast.Name)):
      return target.id
   if (isinstance(target, (ast.Tuple, ast.List))):
      return ",".join(_target_name(i) for i in target.elts)
   return ast.unparse(target)

def parse(statements_raw):
   """Return the statements to be order-checked, given statements_raw, a list
      of ast statements (a module or class body).

      The result always starts with a Section named “UNNAMED”, followed by
      the statements in their original order, with a Section inserted
      wherever a section comment is found in the lines between two
      consecutive statements. Statements whose class is not in CLASS_ORDER
      are left out, as is anything except a class whose first line carries a
      “# 👻” comment. Every item returned has a name attribute."""
   statements = [Section(0, 0, "UNNAMED")]
   # Add in section comments.
   for (s_cur, s_next) in zip(statements_raw, statements_raw[1:]):
      statements.append(s_cur)
      if (s_cur.end_lineno + 1 != s_next.lineno):
         # Gap between statements parsed by ast.parse(). Maybe it’s a section
         # comment?
         section = Section.parse(s_cur.end_lineno + 1, s_next.lineno - 1)
         if (section is not None):
            statements.append(section)
   # The loop above never appends the final statement (which is also the
   # only statement if there is just one). Slice so empty bodies work too.
   statements.extend(statements_raw[-1:])
   # Remove statement types exempt from ordering.
   statements = [s for s in statements
                 if (s.__class__ in CLASS_ORDER or s.__class__ == Section)]
   # Remove statements special-case exempt from ordering.
   kept = list()
   for stmt in statements:
      if (re.search(r"# +👻", lines[stmt.lineno]) is None):
         kept.append(stmt)
      elif (isinstance(stmt, ast.ClassDef)):
         FAIL(stmt.lineno, "%s: no exemptions for classes" % stmt.name)
         kept.append(stmt)  # still subject to ordering
   statements = kept
   # Set statement names.
   for stmt in statements:
      if (isinstance(stmt, ast.Import)):
         if (len(stmt.names) != 1):
            FAIL(stmt.lineno, "too many imports on same line")
         stmt.name = stmt.names[0].name
      elif (isinstance(stmt, ast.Assign)):
         # For chained assignment (a = b = 1), name it after the first target.
         stmt.name = _target_name(stmt.targets[0])
      elif (isinstance(stmt, (ast.AnnAssign, ast.AugAssign))):
         stmt.name = _target_name(stmt.target)
      elif (isinstance(stmt, (Section, ast.ClassDef, ast.FunctionDef))):
         pass  # already has name attribute
      else:
         assert False, "invalid statement type: %s" % type(stmt)
   # Done.
   return statements

def _decorator_name(decorator_list):
   """Return the name of the first decorator in decorator_list as a string,
      or None if the list is empty or the decorator has no usable name.

      A decorator that is a call, e.g. @foo(1), is named after what is
      called. Dotted decorators give the dotted name, except that anything
      ending in “.setter” is reported as “property”."""
   if (len(decorator_list) == 0):
      return None
   dec = decorator_list[0]
   if (isinstance(dec, ast.Call)):
      dec = dec.func
   if (isinstance(dec, ast.Name)):
      return dec.id
   if (isinstance(dec, ast.Attribute)):
      if (dec.attr == "setter"):
         return "property"  # setter is close enough
      return ast.unparse(dec)
   return None

def sort_key(stmt):
   """Return the tuple that statement stmt sorts by. stmt must have gone
      through parse() so it has a name attribute. The first element is
      always the rank of the statement’s class in CLASS_ORDER; the rest
      depends on the statement type:

        import:      (rank, origin, module name), where origin is 1 for
                     standard library, 2 for neither standard library nor
                     Charliecloud, 3 for charliecloud itself, and 4 for
                     modules in CH_MODULES
        assignment:  (rank,)
        function:    (rank, kind, function name), where kind is 1 for
                     __init__, 2 for static methods, 3 for class methods,
                     4 for other double-underscore methods, 5 for
                     properties, and 6 for everything else
        class:       (rank, class name)"""
   ret = [CLASS_ORDER[stmt.__class__]]
   if (isinstance(stmt, ast.Import)):
      name = stmt.name.split(".")[0]  # top-level package decides the origin
      if (name == "charliecloud"):
         ret.append(3)
      elif (name in STDLIB_MODULES):
         ret.append(1)
      elif (name in CH_MODULES):
         ret.append(4)
      else:  # neither standard library nor Charliecloud
         ret.append(2)
      ret.append(stmt.name)
   elif (isinstance(stmt, (ast.Assign, ast.AnnAssign, ast.AugAssign))):
      pass  # all assignments of the same kind are equal
   elif (isinstance(stmt, ast.FunctionDef)):
      decorator = _decorator_name(stmt.decorator_list)
      # The name tests come first, so a decorated __init__ or other
      # double-underscore method is still classified by its name.
      if (stmt.name == "__init__"):
         ret.append(1)
      elif (re.search(r"^__.+__$", stmt.name)):
         ret.append(4)
      elif (decorator == "staticmethod"):
         ret.append(2)
      elif (decorator == "classmethod"):
         ret.append(3)
      elif (decorator == "property"):
         ret.append(5)
      else:
         ret.append(6)
      ret.append(stmt.name)
   elif (isinstance(stmt, ast.ClassDef)):
      # NOTE: This does *not* consider inheritance relationships. That is
      # special cased by the caller.
      ret.append(stmt.name)
   else:
      assert False, "unreachable code reached"
   return tuple(ret)

def validate(statements):
   """Check the ordering of statements, a list of ast statements (a module
      or class body), reporting each problem found with FAIL() or FAILO().

      Two things are checked: (1) that the recognized sections, i.e. those
      named in SECTION_ORDER, appear in that order, with no repeats and
      without mixing §Classes with other “classes” sections, and (2) that
      within each section every pair of statements is in sort_key() order,
      except that a class may always follow a class it inherits from."""
   statements = parse(statements)

   # Validate section order. Compare every pair, not just neighbors, so each
   # misplaced section is reported against everything it should follow.
   known = [stmt for stmt in statements
            if isinstance(stmt, Section) and stmt.name in SECTION_ORDER]
   for (i, before) in enumerate(known):
      for after in known[i+1:]:
         if (SECTION_ORDER[before.name] >= SECTION_ORDER[after.name]):
            FAILO(before, after)
         if (   (before.name == "Classes" and "classes" in after.name.lower())
             or (after.name == "Classes" and "classes" in before.name.lower())):
            FAIL(before.lineno, ("§%s not allowed if §%s (line %d) present"
                                 % (before.name, after.name, after.lineno)))

   # Build statements within sections. This is a list of lists rather than a
   # dictionary keyed by section name so that a repeated section name does
   # not throw away the statements of the earlier section. parse() always
   # puts a Section first, so there is a current group by the time the first
   # real statement shows up.
   groups = list()
   for stmt in statements:
      if (isinstance(stmt, Section)):
         groups.append(list())
      else:
         groups[-1].append(stmt)

   # Validate order within each section. This uses an O(n²) all-to-all
   # comparison so we can consider class inheritance.
   for group in groups:
      for (i, before) in enumerate(group):
         for after in group[i+1:]:
            if (not (   inherits(after, before)
                     or sort_key(before) <= sort_key(after))):
               FAILO(before, after)


# Check the module’s top-level statements first, then the body of each
# top-level class. Exit status is 1 if anything was reported, else 0.
validate(tree.body)
for class_def in [s for s in tree.body if isinstance(s, ast.ClassDef)]:
   validate(class_def.body)
print("total errors: %d" % error_ct)
sys.exit(1 if (error_ct != 0) else 0)
