# Get the hierarchy levels of GO terms (获取GO的层级)

#encoding:utf-8
'''
Created on 2024.7.29

@author:
'''

import re
import argparse

# Command-line interface. The description reads: "parse the GO .obo file
# from the official site to obtain hierarchy information".
parser = argparse.ArgumentParser(description='解析Go官网的obo文件,可以获取层级信息')
# NOTE: despite its help text ("pathway ids of interest"), --term_list is the
# path of a tab-separated table whose first column holds GO ids; it is opened
# as a file further down. It is required because the script cannot run
# without it (previously a missing value crashed on None.strip()).
parser.add_argument('--term_list', type=str, required=True, help='关注的通路id')
argv = vars(parser.parse_args())

term_list = argv['term_list'].strip()


# Root term id of each GO namespace, keyed by namespace name.
GO_level1 = {"biological_process":"GO:0008150", "molecular_function":"GO:0003674", "cellular_component":"GO:0005575"}

class GOBase(object):
    """One GO term node, as stored in ObOs.map.

    Only the id is known when the node is created; every other field keeps a
    "not filled in yet" default until the term's [Term] stanza is parsed or
    its level / ancestor list is computed.
    """

    def __init__(self, _id):
        self._id = _id
        self.alt_ids = []             # secondary ids from "alt_id:" lines
        self.name, self.namespace = '', ''
        self.parent = None            # list of parent ids; None until a stanza is parsed
        self.level = -1               # -1 means "depth not computed yet"
        self.allParents = None        # cached ancestor list; None means "not computed yet"

class ObOs(object):
    """Parse a GO ``.obo`` file and index its [Term] stanzas by GO id.

    After construction, ``self.map`` maps every primary id and every
    ``alt_id`` to a GOBase node. An id that is only ever named as a parent
    (on an ``is_a:`` or ``relationship:`` line) and never gets a [Term] stanza
    of its own stays a placeholder node whose ``parent`` is None.
    """

    # A GO identifier: "GO:" followed by seven digits.
    _GO_ID = re.compile(r'GO:\d{7}')

    def __init__(self, path):
        self.path = path
        self.map = {}
        self.parseObO()

    def parseObO(self):
        """Read ``self.path`` and pass the lines of each [Term] stanza to parseGO.

        A stanza ends at a blank line, at the next ``[...]`` header, or at end
        of file. Other stanza types (e.g. ``[Typedef]``) are skipped.
        """
        _goTxt = []
        flag = False  # True while inside a [Term] stanza
        # Read as UTF-8 (the encoding the OBO format specifies) rather than
        # relying on the platform default; `with` guarantees the file closes.
        with open(self.path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('['):
                    # A new header also closes the current stanza, in case
                    # the usual blank separator line is missing.
                    if flag:
                        self.parseGO(_goTxt)
                    _goTxt = []
                    flag = line.startswith('[Term]')
                elif flag:
                    if line:
                        _goTxt.append(line)
                    else:
                        self.parseGO(_goTxt)
                        _goTxt = []
                        flag = False
        # The final stanza need not be followed by a blank line.
        if flag:
            self.parseGO(_goTxt)

    @classmethod
    def _first_go_id(cls, text):
        """Return the first GO id found in ``text``, or None if there is none."""
        match = cls._GO_ID.search(text)
        return match.group(0) if match else None

    def parseGO(self, _goText):
        """Create or update the GOBase node described by one [Term] stanza.

        Reads the ``id``, ``name``, ``namespace``, ``alt_id``, ``is_a`` and
        ``relationship`` tags. Both ``is_a`` and ``relationship`` targets are
        recorded as parents. Each alt_id is mapped to the same node as the
        primary id. A stanza without a GO id is ignored.
        """
        _id = None
        _name = ''
        _namespace = ''
        _is_as = []
        _alt_ids = []
        for _txt in _goText:
            if _txt.startswith('id:'):
                _id = self._first_go_id(_txt)
            elif _txt.startswith('name:'):
                _name = _txt[5:].strip()
            elif _txt.startswith('namespace:'):
                _namespace = _txt[10:].strip()
            elif _txt.startswith('alt_id:'):
                _alt = self._first_go_id(_txt)
                if _alt:
                    _alt_ids.append(_alt)
            elif _txt.startswith(('is_a:', 'relationship:')):
                # Lines without a GO id are skipped instead of yielding a
                # garbage slice.
                _parent = self._first_go_id(_txt)
                if _parent:
                    _is_as.append(_parent)
        if not _id:
            return
        # Reuse a placeholder node created earlier when this id was first
        # seen as somebody's parent.
        _go = self.map.get(_id)
        if _go is None:
            _go = GOBase(_id)
        _go.name = _name
        _go.namespace = _namespace
        _go.parent = self.parseParent(_is_as)
        _go.alt_ids = _alt_ids
        self.map[_id] = _go
        for _alt in _alt_ids:
            self.map[_alt] = _go

    def parseParent(self, is_as):
        """Ensure every parent id has a node in self.map and return the ids as a list."""
        __parent = []
        for isa in is_as:
            if isa not in self.map:
                self.map[isa] = GOBase(isa)
            __parent.append(isa)
        return __parent

    def getLevel(self, _id):
        """Return the depth of term ``_id``.

        A term with no known parents (this includes placeholder nodes whose
        ``parent`` is None) has level 1. Any other term has one more than the
        smallest level among its parents. The result is memoized in the
        node's ``level`` field.
        """
        _go = self.map[_id]
        if _go.level > 0:
            return _go.level
        if not _go.parent:
            _go.level = 1
        else:
            _go.level = min(self.getLevel(g) for g in _go.parent) + 1
        return _go.level

    def getAllParent(self, _id):
        """Return ``[_id]`` followed by each parent's own getAllParent list, in parent order.

        The result is a pre-order listing of the term's ancestor tree.
        Ancestors reachable by several routes appear once per route; they
        are deliberately not de-duplicated.

        The list is cached in the node's ``allParents`` field, headed by the
        node's primary id. Each call returns a fresh list whose first element
        is the id the caller asked for, which may be an alt_id. Callers may
        therefore modify the result without corrupting the cache.
        """
        _go = self.map[_id]
        if _go.allParents is None:
            _prs = [_go._id]
            for g in (_go.parent or []):
                _prs.extend(self.getAllParent(g))
            _go.allParents = _prs
        return [_id] + _go.allParents[1:]

    def get_name(self, _id):
        """Return the ``name:`` of term ``_id`` ('' for placeholder nodes)."""
        return self.map[_id].name

    def get_alt_ids(self, _id):
        """Return the list of alt_ids recorded for term ``_id``."""
        return self.map[_id].alt_ids

if __name__ == '__main__':
    # Location of the GO ontology (.obo) file that gets parsed.
    obo = '/TJPROJ6/RNA_SH/personal_dir/lizhengnan/scripts/GO/go-basic.obo'
    ob = ObOs(obo)

# Root terms of the three GO namespaces (biological_process,
# molecular_function, cellular_component).
go_level1 = [
    "GO:0008150",
    "GO:0003674",
    "GO:0005575",
]

# Raw lines of the --term_list table, header line included.
with open(term_list) as term_file:
    sheet = list(term_file)
def print_res(go_id, _res, fields=None, obo=None, roots=None):
    """Print the root-to-term paths encoded in a flattened ancestor list.

    ``_res`` has the shape of ObOs.getAllParent's result. ``_res[0]`` is the
    term printed last on every path; ``_res[1:]`` are its ancestors in the
    order getAllParent emits them. Walking ``_res[1:]``, ids are collected
    until a namespace root is reached. That path is then printed root-first
    as ``id<TAB>name<TAB>`` pairs, followed by ``fields[1:]`` joined with
    tabs. Collection then restarts from ``_res[0]``.

    Args:
        go_id: not used by the walk; kept for the existing call signature.
        _res: flattened ancestor list (read only; never modified).
        fields: the split input row; defaults to the module-level ``line``.
        obo: object providing ``get_name(id)``; defaults to the module-level ``ob``.
        roots: ids that end a path; defaults to the module-level ``go_level1``.

    NOTE(review): because collection restarts from ``_res[0]`` after each
    root, paths through the second or later parent of an intermediate
    ancestor omit that ancestor and the nodes below it.
    """
    if fields is None:
        fields = line
    if obo is None:
        obo = ob
    if roots is None:
        roots = go_level1
    tail = "\t".join(fields[1:])
    # Work on a copy so the caller's (possibly cached) list is untouched.
    _sub_prs = _res[:1]
    for g in _res[1:]:
        # Prepend, so the finished path reads root-first.
        _sub_prs.insert(0, g)
        if g in roots:
            for gg in _sub_prs:
                print(gg, obo.get_name(gg), sep='\t', end="\t")
            print(tail)
            _sub_prs = _res[:1]
def _root_paths(term_id):
    """Yield every ancestor path of term_id, ordered root-first with term_id last.

    Parents come from ``ob.map`` (both is_a and relationship edges). A term
    with no known parents ends a path. A yielded path therefore starts at a
    GO namespace root only if its chain reaches one; callers filter on that.
    Paths are yielded in the same parent order that ObOs.getAllParent uses.
    """
    parents = ob.map[term_id].parent
    if not parents:
        yield [term_id]
        return
    for parent_id in parents:
        for path in _root_paths(parent_id):
            yield path + [term_id]


# Annotate every data row (the header row is skipped) with its GO ancestry.
# Each root-to-term path produces one output line: "id<TAB>name<TAB>" pairs
# from the namespace root down to the row's term, then the row's remaining
# columns. Rows with no such path (root or obsolete terms without parents,
# ids missing from the ontology) are echoed unchanged, exactly once.
_roots = set(go_level1)
for l in sheet[1:]:
    line = l.strip().split("\t")
    go_id = line[0]
    if not go_id:
        # Blank line: nothing to look up.
        continue
    paths = []
    if go_id in ob.map:
        if ob.map[go_id].parent:
            paths = [p for p in _root_paths(go_id) if p[0] in _roots]
        else:
            # The term itself has no parents: use the first alt_id node that
            # has some, still labelling the end of each path with go_id.
            for al_id in ob.get_alt_ids(go_id):
                if ob.map[al_id].parent:
                    paths = [p[:-1] + [go_id] for p in _root_paths(al_id) if p[0] in _roots]
                    break
    if not paths:
        print("\t".join(line))
        continue
    rest = "\t".join(line[1:])
    for path in paths:
        for node in path:
            print(node, ob.get_name(node), sep='\t', end="\t")
        print(rest)