#!/usr/bin/env python

##
## palinka2mmax.py -- converts NP4E palinka files to MMAX2 format
## version: 2009-01-30
## 

## Copyright 2009 Yannick Versley / CiMeC Univ. Trento
## 
## Licensed under the Apache License, Version 2.0 (the "License");
## you may not use this file except in compliance with the License.
## You may obtain a copy of the License at
## 
## http://www.apache.org/licenses/LICENSE-2.0
## 
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS,
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
## See the License for the specific language governing permissions and
## limitations under the License.

## NOTE:
## if you want to use this script to convert your own Palinka annotation
## to MMAX2 format, you need to do the following:
## * depending on the link elements you are using, change the
##   elem.tag in ['COREF'] (in palinka2markables) to the list
##   of coreference relations you are using. The subsequent check
##   (of TYPE_REF) can be deleted if you don't have relation subtypes.
## * cut_markable is called whenever we think that the minimum span of
##   a markable should end, i.e., when we see postmodification,
##   appositions or the like. Don't worry about this if you are not using
##   BART -- BART uses the min_ids attribute to help match gold and system
##   markables against each other.
## * the script is called with the DIRECTORY NAME where your PALINKA files
##   are found and the DIRECTORY NAME where the MMAX2 files should end up.
##   The MMAX2 directory has the following internal structure:
##   + Basedata/   contains all the ..._words.xml files with tokens
##   + markables/  contains the markable files.
##   To open the files in MMAX2, you need additional boilerplate files:
##   the common/ directory, as well as Basedata/tokens.dtd and
##   markables/markables.dtd

import sys
import re
from glob import glob
import xml.etree.cElementTree as etree
from xml.sax.saxutils import quoteattr,escape
import os.path

# Tokens treated as punctuation when trimming markable boundaries
# (frob_boundaries): ( ) . , " or the possessive 's. Used with .match(),
# so only a PREFIX of the token has to match (e.g. ",000" also matches).
punct_re=re.compile(r"[\(\)\.,\"]|'s")

## TODO: use FUNC labels to derive min_ids property
def maybe_int(s):
    """Return int(s) if *s* parses as an integer, otherwise *s* itself.

    Used to normalise Palinka ID/SRC attribute values so that numeric ids
    compare equal regardless of how they were read.
    """
    try:
        value = int(s)
    except ValueError:
        value = s
    return value

def frob_boundaries(ta,start_pos,end_pos,pattern=None):
    """Trim leading/trailing punctuation tokens from the span [start_pos, end_pos).

    ta: list of token strings; positions are 0-based and end-exclusive.
    pattern: compiled regex whose .match() marks a token as punctuation;
        defaults to the module-level punct_re.
    Returns the trimmed (start, end) pair. If the span is empty or consists
    only of punctuation, it is returned unchanged instead of being trimmed
    to nothing.
    """
    if pattern is None:
        pattern = punct_re
    start, end = start_pos, end_pos
    # The bounds checks keep both scans inside the span. Without them an
    # all-punctuation span scans past its end (possibly off the end of ta)
    # and comes back inverted (start > end).
    while start < end and pattern.match(ta[start]):
        start += 1
    while end > start and pattern.match(ta[end - 1]):
        end -= 1
    if start >= end:
        return (start_pos, end_pos)
    return (start, end)

def cut_markable(m,pos):
    """Record *pos* as the end of the minimal span of markable *m*.

    m: a (start_pos, attrs) pair as kept on the open-markable stack.
    The minimal span (start_pos, pos) is stored as attrs['min_ids'].
    Only the first usable cut is kept. A cut at or before the markable start
    is ignored, because it would give an empty min span that make_span
    renders as an inverted range.
    """
    start, attrs = m[0], m[1]
    if 'min_ids' in attrs or pos <= start:
        return
    attrs['min_ids'] = (start, pos)

# Palinka structural elements copied to MMAX2 levels:
# tag -> (MMAX2 level name, attrs); palinka2markables reads only the level name.
doc_markup = dict(
    P=('section', {}),
    S=('sentence', {}),
)

# Palinka PUNCT spellings rewritten to a plain token.
_PUNCT_FIXES = {'"_': '"', '_"': '"', "'_": "'", "_'": "'", '$--': '--'}


def _split_word(txt):
    """Split the text of one Palinka W element into basedata tokens.

    A trailing possessive 's becomes its own token (unless the word is just
    "'s"). Words joined with '_' are split with empty pieces dropped, unless
    the text starts with '_', in which case it is kept whole.
    """
    if txt.endswith("'s") and len(txt) > 2:
        return [txt[:-2], txt[-2:]]
    if '_' in txt:
        pieces = txt.split('_')
        if pieces[0]:
            # eliminate '' in can _ ''
            return [p for p in pieces if p != '']
    return [txt]


def _normalize_punct(txt):
    """Map Palinka's quote/dash spellings in PUNCT elements to plain tokens."""
    return _PUNCT_FIXES.get(txt, txt)


def palinka2markables(src):
    """Parse a Palinka XML document into tokens and MMAX2 annotations.

    src: anything accepted by etree.iterparse (file name or file object).
    Returns (ta, aa):
      ta -- list of token strings built from W and PUNCT elements.
      aa -- list of (level, id, attrs, start_pos, end_pos) tuples, where
            [start_pos, end_pos) are 0-based token positions in ta.
            MARKABLE elements give level 'coref', with the id taken from the
            ID attribute (via maybe_int) and attrs possibly holding 'COREF'
            and 'min_ids'. P/S elements give the level named in doc_markup,
            numbered from 1; P/S elements that cover no tokens are dropped.
    """
    ta = []
    aa = []
    markables = []  # stack of (start_pos, attrs) for currently open MARKABLEs
    other = []      # stack of (level, start_pos, attrs) for open P/S elements
    other_id = 0
    pos = 0
    for evt, elem in etree.iterparse(src, events=("start", "end")):
        if evt == 'start':
            if elem.tag == 'MARKABLE':
                markables.append((pos, {}))
            elif elem.tag in doc_markup:
                other.append((doc_markup[elem.tag][0], pos, {}))
        else:
            if elem.tag == 'MARKABLE':
                start_pos, attrs = markables.pop()
                markable_id = maybe_int(elem.attrib['ID'])
                aa.append(('coref', markable_id, attrs, start_pos, pos))
            elif elem.tag in doc_markup:
                lvl, start_pos, attrs = other.pop()
                if pos > start_pos:
                    other_id += 1
                    aa.append((lvl, other_id, attrs, start_pos, pos))
            elif elem.tag in ['COREF']:
                # the link belongs to the innermost open MARKABLE
                markables[-1][1]['COREF'] = maybe_int(elem.attrib['SRC'])
                # .get(): TYPE_REF is absent when there are no relation subtypes
                if elem.attrib.get('TYPE_REF') in ('BRACKETED_TEXT', 'APPOSITION'):
                    if len(markables) >= 2:
                        # end the min span of the enclosing markable here
                        cut_markable(markables[-2], pos)
            elif elem.tag == 'W':
                # an empty W element has text None; treat it as ''
                tokens = _split_word(elem.text or '')
                ta.extend(tokens)
                pos += len(tokens)
                if elem.attrib.get('POS') in ('PREP', 'EN') and markables:
                    # the min span of the open markable ends before this word
                    cut_markable(markables[-1], pos - 1)
            elif elem.tag == 'PUNCT':
                ta.append(_normalize_punct(elem.text or ''))
                pos += 1
        assert pos == len(ta)
    return (ta, aa)

def write_basedata(dirname,basename,ta):
    """Write tokens *ta* as the MMAX2 words file <basename>_words.xml.

    The file is written to <dirname>/Basedata/, which must already exist.
    The i-th token (0-based) becomes <word id="word_{i+1}">, XML-escaped.
    """
    import codecs
    path = os.path.join(dirname, 'Basedata/%s_words.xml' % (basename,))
    # The XML declaration promises UTF-8, so encode explicitly. On Python 2,
    # writing non-ASCII unicode tokens to a plain file raises
    # UnicodeEncodeError.
    with codecs.open(path, 'w', 'utf-8') as f:
        f.write('''<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE words SYSTEM "words.dtd">
<words>
''')
        for i, w in enumerate(ta):
            f.write('  <word id="word_%d">%s</word>\n' % (i + 1, escape(w)))
        f.write('</words>\n')

def make_span(a,b):
    """Format the 0-based, end-exclusive token range [a, b) as an MMAX2 span.

    A single token gives 'word_N'; longer ranges give 'word_N..word_M'
    (MMAX2 word ids are 1-based).
    """
    first = 'word_%d' % (a + 1,)
    if b == a + 1:
        return first
    return '%s..word_%d' % (first, b)

def write_markables(dirname,basename,aa):
    """Write annotations to one MMAX2 markable file per level.

    aa holds (level, id, attrs, start_pos, end_pos) tuples with 0-based,
    end-exclusive token positions. Markables of each level go to
    <dirname>/markables/<basename>_<level>_level.xml; the markables/
    directory must already exist. Tuple-valued attributes (such as min_ids)
    are written as word spans; other values are stringified and XML-quoted.
    """
    import codecs
    fs = {}
    try:
        for alvl, m_id, attrs, start_pos, end_pos in aa:
            f = fs.get(alvl)
            if f is None:
                # first markable of this level: open its file, write the header
                f = codecs.open(
                    os.path.join(dirname, 'markables/%s_%s_level.xml' % (basename, alvl)),
                    'w', 'utf-8')
                fs[alvl] = f
                f.write('''<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE markables SYSTEM "markables.dtd">
<markables xmlns="www.eml.org/NameSpaces/%s">
''' % (alvl,))
            # %s rather than %d: maybe_int leaves non-numeric ids as strings,
            # and filter_coref already formats dir_antecedent as 'markable_%s'.
            f.write('<markable id="markable_%s" span="%s" mmax_level="%s"' % (
                m_id, make_span(start_pos, end_pos), alvl))
            for k, v in sorted(attrs.items()):
                if isinstance(v, tuple):
                    assert len(v) == 2
                    f.write(' %s="%s"' % (k, make_span(v[0], v[1])))
                else:
                    # '%s' % (v,) instead of str(v): on Python 2, unicode values
                    # stay unicode so codecs can encode them as UTF-8
                    f.write(' %s=%s' % (k, quoteattr('%s' % (v,))))
            f.write('/>\n')
        for f in fs.values():
            f.write('</markables>\n')
    finally:
        # close every level file even if writing failed part-way
        for f in fs.values():
            f.close()

def get_basename(fname):
    """Derive the document base name from a Palinka file path.

    Strips the directory, then every trailing '.xml', then every trailing
    '-done', then every trailing 'done', in that order.
    """
    name = os.path.basename(fname)
    for suffix in ('.xml', '-done', 'done'):
        while name.endswith(suffix):
            name = name[:-len(suffix)]
    return name

def links2sets(aa):
    """Group linked 'coref' markables into coreference sets.

    Each 'coref' annotation whose attrs contain 'COREF' contributes a link
    (its own id, attrs['COREF']). Linked ids are merged transitively into
    sets named "set_<n>". Afterwards every 'coref' annotation whose id is in
    a set gets attrs['coref_set'] = that set name. aa is updated in place
    through the attrs dicts; nothing is returned.
    """
    # sets: markable id -> set name; elements: set name -> list of member ids
    sets={}
    elements={}
    next_set_id=0
    for alvl,m_id,attrs,start_pos,end_pos in aa:
        if alvl!='coref' or 'COREF' not in attrs:
            continue
        link=(maybe_int(m_id),maybe_int(attrs['COREF']))
        # set names the two endpoints already belong to (0, 1 or 2 entries;
        # the two entries may be the same set)
        ks=[sets[mid] for mid in link if mid in sets]
        if len(ks)==0:
            # neither endpoint seen yet: open a new set
            set_id="set_%d"%(next_set_id,)
            elements[set_id]=[]
            next_set_id+=1
        elif len(ks)==1:
            set_id=ks[0]
        else:
            # both endpoints already placed: fold all members into ks[0]'s set.
            # The other set's member list is emptied; its name is never reused
            # because next_set_id only grows.
            a=[]
            set_id=ks[0]
            for set2 in ks:
                elm=elements[set2]
                elements[set2]=[]
                a.extend(elm)
                for e in elm:
                    sets[e]=set_id
            elements[set_id]=a
        # record both endpoints in set_id, adding each to the member list
        # only if it was not already there
        for e in link:
            old_setid=sets.get(e,'')
            sets[e]=set_id
            if old_setid!=set_id:
                elements[set_id].append(e)
    # copy set membership onto the annotations themselves
    for anno in aa:
        if anno[0]!='coref' or anno[1] not in sets:
            continue
        anno[2]['coref_set']=sets[anno[1]]

def filter_coref(aa):
    """Drop unlinked 'coref' annotations and rename COREF to dir_antecedent.

    'coref' annotations without a 'coref_set' attribute are left out; all
    other annotations are kept in order. A raw 'COREF' attribute is replaced
    by 'dir_antecedent' = 'markable_<COREF>'. The attrs dicts are modified
    in place; a new list is returned.
    """
    kept = []
    for anno in aa:
        kind, attrs = anno[0], anno[2]
        if kind == 'coref' and 'coref_set' not in attrs:
            continue
        if 'COREF' in attrs:
            attrs['dir_antecedent'] = 'markable_%s' % (attrs.pop('COREF'),)
        kept.append(anno)
    return kept

def correct_boundaries(ta,aa):
    """Trim punctuation from the spans of all 'coref' annotations.

    For each 'coref' annotation, both its main span (positions 3 and 4) and
    its 'min_ids' span, if present, are passed through frob_boundaries. The
    'min_ids' value is updated in the annotation's attrs dict in place.
    Other annotations are passed through unchanged. Returns a new list.
    """
    fixed = []
    for anno in aa:
        if anno[0] == 'coref':
            span = frob_boundaries(ta, anno[3], anno[4])
            attrs = anno[2]
            if 'min_ids' in attrs:
                attrs['min_ids'] = frob_boundaries(ta, *attrs['min_ids'])
            anno = tuple(anno[:3]) + tuple(span) + tuple(anno[5:])
        fixed.append(anno)
    return fixed

def write_dotmmax(dirname,base):
    """Write the MMAX2 project file <dirname>/<base>.mmax.

    The project references <base>_words.xml and the muc_style.xsl stylesheet.
    """
    # 'with' closes the file even if the write fails
    with open(os.path.join(dirname, '%s.mmax' % (base,)), 'w') as f:
        f.write('''<?xml version="1.0"?>
<mmax_project>
<turns></turns>
<words>%s_words.xml</words>
<gestures></gestures>
<keyactions></keyactions>
<views>
<stylesheet>muc_style.xsl</stylesheet>
</views>
</mmax_project>
''' % (base,))
    
def _ensure_dir(path):
    """Create directory *path* (with parents) if it does not exist yet."""
    if not os.path.isdir(path):
        os.makedirs(path)


def main(argv):
    """Convert every *.xml file in argv[1] into MMAX2 files under argv[2].

    Returns the process exit status: 0 on success, 1 on bad usage.
    """
    if len(argv) < 3:
        prog = argv[0] if argv else 'palinka2mmax.py'
        sys.stderr.write('usage: %s PALINKA_DIR MMAX2_DIR\n' % (prog,))
        return 1
    in_dir, out_dir = argv[1], argv[2]
    # write_basedata / write_markables open files in these subdirectories
    # and would fail if they were missing
    for sub in ('Basedata', 'markables'):
        _ensure_dir(os.path.join(out_dir, sub))
    # sorted(): deterministic processing and progress-output order
    for fname in sorted(glob(os.path.join(in_dir, '*.xml'))):
        base = get_basename(fname)
        ta, aa = palinka2markables(fname)
        print("%s => %s" % (fname, base))
        links2sets(aa)
        aa = filter_coref(aa)
        aa = correct_boundaries(ta, aa)
        write_basedata(out_dir, base, ta)
        write_markables(out_dir, base, aa)
        write_dotmmax(out_dir, base)
    return 0


if __name__=='__main__':
    sys.exit(main(sys.argv))

