Spaces:
Build error
Build error
| import numpy as np | |
| import timeit | |
| import torch | |
| from transformers import AutoTokenizer | |
| #Node | |
| class Node: | |
| #init | |
| def __init__(self, letter="", parent=None, root=False, last=False, depth=0, states=8): | |
| self.exploitation_score = 0 # Exploitaion score | |
| self.visits = 1 #How many visits | |
| self.letter = letter #This node's letter | |
| self.parent = parent #This node's parent node | |
| self.states = states #How many states in node | |
| self.children = np.array([None for _ in range(self.states)]) #This node's children | |
| self.children_stat = np.zeros(self.states, dtype=bool) #Which stat are expanded | |
| self.root = root # Is root? boolean | |
| self.last = last # Is last node? | |
| self.depth = depth # My depth | |
| self.letters =["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"] | |
| #next_node | |
| def next_node(self, child=0): #Return next node | |
| assert self.children_stat[child] == True, "No child in here." | |
| return self.children[child] | |
| #back_parent | |
| def back_parent(self): #Go back to parent | |
| return self.parent, letters_map[self.letter] | |
| #generate_child | |
| def generate_child(self, child=0, last=False): #Generate child | |
| assert self.children_stat[child] == False, "Already tree generated child at here" | |
| self.children[child] = Node(letter=self.letters[child], parent=self, last=last, depth=self.depth+1, states=self.states) #New node | |
| self.children_stat[child] = True #Stat = True | |
| return self.children[child] | |
| #backpropagation | |
| def backpropagation(self, score=0): | |
| self.visits += 1 # +1 to visit | |
| if self.root == True: # if root, then stop | |
| return self.exploitation_score | |
| else: | |
| self.exploitation_score += score #Add score to exploitation score | |
| return self.parent.backpropagation(score=score) #Backpropagation to parent node | |
| #UCT | |
| def UCT(self): | |
| return (self.exploitation_score / self.visits) + np.sqrt(np.log(self.parent.visits) / (2 * self.visits)) #UCT score | |
| #MCTS | |
| class MCTS: | |
| def __init__(self, target_encoded, depth=20, iteration=1000, states=8, target_protein="", device='cpu', esm_alphabet=None): | |
| self.states = states #How many states | |
| self.root = Node(letter="", parent=None, root=True, last=False, states=self.states) #root node | |
| self.depth = depth #Maximum depth | |
| self.iteration = iteration #iteration for expand | |
| self.target_protein = target_protein #target protein's amino acid sequence | |
| self.device = device | |
| self.encoded_targetprotein = target_encoded | |
| self.base = "" | |
| self.candidate = "" | |
| self.letters =["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"] | |
| self.esm_alphabet = esm_alphabet | |
| self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True) | |
| def make_candidate(self, classifier): | |
| now = self.root | |
| n = 0 # rounds | |
| start_time = timeit.default_timer() #timer start | |
| while len(self.base) < self.depth * 2: #If now is last node, then stop | |
| n += 1 | |
| print(n, "round start!!!") | |
| for _ in range(self.iteration): | |
| now = self.select(classifier, now=now) #Select & Expand | |
| terminate_time = timeit.default_timer() | |
| time = terminate_time-start_time | |
| base = self.find_best_subsequence() #Find best subsequence | |
| self.base = base | |
| # print("best subsequence:", base) | |
| # print("Depth:", int(len(base)/2)) | |
| # print("%02d:%02d:%2f" % ((time//3600), (time//60)%60, time%60)) | |
| # print("=" * 80) | |
| self.root = Node(letter="", parent=None, root=True, last=False, states=self.states, depth=len(self.base)/2) | |
| now = self.root | |
| self.candidate = self.base | |
| return self.candidate | |
| #selection | |
| def select(self, classifier, now=None): | |
| if now.depth == self.depth: #If last node, then stop | |
| return self.root | |
| next_node = 0 | |
| if np.sum(now.children_stat) == self.states: #If every child is expanded, then go to best child | |
| best = 0 | |
| for i in range(self.states): | |
| if best < now.children[i].UCT(): | |
| next_node = i | |
| best = now.children[i].UCT() | |
| else: #If not, then random | |
| next_node = np.random.randint(0, self.states) | |
| if now.children_stat[next_node] == False: #If selected child is not expanded, then expand and simulate | |
| next_node = self.expand(classifier, child=next_node, now=now) | |
| return self.root #start iteration at this node | |
| return now.next_node(child=next_node) | |
| #expand | |
| def expand(self, classifier, child=None, now=None): | |
| last = False | |
| if now.depth == (self.depth-1): #If depth of this node is maximum depth -1, then next node is last | |
| last = True | |
| expanded_node = now.generate_child(child=child, last=last) #Expand | |
| score = self.simulate(classifier, target=expanded_node) #Simulate | |
| expanded_node.backpropagation(score=score) #Backporpagation | |
| return child | |
| #simulate | |
| def simulate(self, classifier, target=None): | |
| now = target #Target node | |
| sim_seq = "" | |
| while now.root != True: #Parent's letters | |
| sim_seq = now.letter + sim_seq | |
| now = now.parent | |
| sim_seq = self.base + sim_seq | |
| for i in range((self.depth * 2) - len(sim_seq)): #Random child letters | |
| r = np.random.randint(0,self.states) | |
| sim_seq += self.letters[r] | |
| sim_seq = self.reconstruct(sim_seq) | |
| scores = [] | |
| classifier.eval().to('cuda') | |
| with torch.no_grad(): | |
| sim_seq = np.array([sim_seq]) | |
| apta_toks = self.nt_tokenizer.batch_encode_plus(sim_seq, return_tensors='pt', padding='max_length', max_length=275)['input_ids'] | |
| apta_attention_mask = apta_toks != self.nt_tokenizer.pad_token_id | |
| prot_attention_mask = self.encoded_targetprotein != self.esm_alphabet.padding_idx | |
| score, _, _, _ = classifier(apta_toks.to('cuda'), self.encoded_targetprotein.to('cuda'), apta_attention_mask.to('cuda'), prot_attention_mask.to('cuda')) | |
| return score | |
| #recommend | |
| def get_candidate(self): | |
| return self.reconstruct(self.candidate) | |
| def find_best_subsequence(self): | |
| now = self.root | |
| stop = False | |
| base = self.base | |
| for _ in range((self.depth*2) - len(base)): | |
| best = 0 | |
| next_node = 0 | |
| for j in range(self.states): | |
| if now.children_stat[j] == True: | |
| if best < now.children[j].UCT(): | |
| next_node = j | |
| best = now.children[j].UCT() | |
| now = now.next_node(child=next_node) | |
| base += now.letter | |
| # if current node has no expanded children, stop reconstructing. | |
| if np.sum(now.children_stat) == 0: | |
| break | |
| return base | |
| #reconstruct | |
| def reconstruct(self, seq=""): | |
| r_seq = "" | |
| for i in range(0, len(seq), 2): | |
| if seq[i] == '_': | |
| r_seq = r_seq + seq[i+1] | |
| else: | |
| r_seq = seq[i] + r_seq | |
| return r_seq | |
| def reset(self): | |
| self.base = "" | |
| self.candidate = "" | |
| self.root = Node(letter="", parent=None, root=True, last=False, states=self.states) |