--
You received this message because you are subscribed to the Google Groups "Cap'n Proto" group.
To unsubscribe from this group and stop receiving emails from it, send an email to capnproto+...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/capnproto/ac8e4eba-11e9-44cc-9095-4313e4b7e544n%40googlegroups.com.
class DagDataset(Dataset):
    """Dataset of proof steps spread across multiple packed capnp DAG files.

    Indexing metadata (file list, per-file counts, cumulative sums/end
    indices) is precomputed and loaded from `path2dataprep`; the tactic
    hash -> label-index table is loaded from `path2hash2idx`.
    """

    def __init__(self, path2dataprep, path2hash2idx, split):
        # Which split ('train' / 'val' / ...) this instance serves.
        self.split = split
        self.path2dataprep = path2dataprep
        self.path2hash2idx = path2hash2idx

        # Load the precomputed data-prep metadata and slice out this split.
        prep_db = torch.load(self.path2dataprep)
        self.data_prep = prep_db['data_prep']
        self.list_files_current_split = self.data_prep.flatten_lst_files_split[self.split]
        self.list_counts_current_split = self.data_prep.counts[self.split]
        self.list_cummulative_sum_current_split = self.data_prep.cummulative_sum[self.split]
        self.list_cummulative_end_index_current_split = self.data_prep.cummulative_end_index[self.split]
        # Total number of proof-step samples across every file in this split.
        self.length = sum(self.list_counts_current_split)

        # Load the tactic-hash -> label-index mapping used by __getitem__.
        hash_db = torch.load(self.path2hash2idx)
        self.hash2idx = hash_db['hash2idx']
def __len__(self):
    """Return the total number of proof-step samples in the current split."""
    # Precomputed in __init__ as the sum of the per-file proof-step counts.
    return self.length
def __getitem__(self, idx: int):
    """Return the pair ``(node, tactic_label)`` for global proof-step index ``idx``.

    ``idx`` is a global index over all proof steps in the split; it is mapped
    to (file, index-within-file) via the precomputed cumulative end indices.

    Note: the original signature claimed ``-> DagNode`` but the method has
    always returned a 2-tuple; the wrong annotation is removed.
    """
    # Locate which file contains the idx-th proof step (binary search over
    # the sorted cumulative end indices of the files in this split).
    file_idx = bisect.bisect_left(self.list_cummulative_end_index_current_split, idx)
    # Resolve the file name and rewrite it to a path valid on this machine.
    file_name = self.list_files_current_split[file_idx]
    file_name = self.convert_to_local_home_path(file_name)
    # `with` guarantees the handle is closed even if parsing or any lookup
    # below raises (the original leaked the handle on exception).
    # NOTE(review): capnp's read_packed normally expects a binary-mode file;
    # text mode is kept to match the original behavior — confirm 'rb' isn't needed.
    with open(file_name) as f:
        current_dag_file = dag_api_capnp.Dag.read_packed(
            f, traversal_limit_in_words=2 ** 64 - 1
        )
        # Convert the global index into an index relative to this file.
        prev_cummulative_sum = self.get_previous_cummulative_sum(file_idx)
        idx_rel_this_file = idx - prev_cummulative_sum
        # Pull the proof step once instead of indexing proofSteps twice.
        proof_step = current_dag_file.proofSteps[idx_rel_this_file]
        node_idx = proof_step.node
        tactic_label = self.hash2idx[proof_step.tactic]
        # NodeRef(node_idx, 0): the 0 marks the node as living in the
        # current file (current_dag_file), not a cross-file reference.
        node_ref = NodeRef(node_idx, 0)
        node = DagNode(current_dag_file, node_ref)
    return node, tactic_label
# @profile
def train(self, n_epoch):
import time
self.tactic_predictor.train()
avg_loss = AverageMeter('train loss')
avg_acc = AverageMeter('train accuracy')
# iterations = len(self.dataloaders['train'])
# bar = ProgressBar(max_value=iterations)
self.dataloaders['train'] = iter(self.dataloaders['train'])
# for i, data_batch in enumerate(self.dataloaders['train']):
for i in range(len(self.dataloaders['train'])):
data_batch = next(self.dataloaders['train'])
data_batch = process_batch_ddp(self.opts, data_batch)
# loss, logits = self.tactic_predictor(data_batch)
# acc = accuracy(output=logits, target=data_batch[1])
# avg_loss.update(loss, self.opts.batch_size)
# avg_acc.update(acc, self.opts.batch_size)
self.log(f'{i=}')
#self.log(f"{i=}: {loss=}")
# self.optimizer.zero_grad()
# loss.backward() # each process synchronizes it's gradients in the backward pass
# self.optimizer.step() # the right update is done since all procs have the right synced grads
# del loss
# del logits
# del data_batch
gc.collect()
# bar.update(i)
if i >= 10:
time.sleep(2)
sys.exit()
return avg_loss.item(), avg_acc.item()
# time.sleep(1)
loss, logits = self.tactic_predictor(data_batch)
# time.sleep(1)
# self.mem_test()
# time.sleep(1)
# acc = accuracy(output=logits, target=data_batch[1])
# avg_loss.update(loss, self.opts.batch_size)
# avg_acc.update(acc, self.opts.batch_size)
self.log(f'{i=}')
self.log(f"{i=}: {loss=}")
# time.sleep(1)
# self.mem_test()
# time.sleep(1)
# self.optimizer.zero_grad()
# loss.backward() # each process synchronizes it's gradients in the backward pass
# self.optimizer.step() # the right update is done since all procs have the right synced grads
# time.sleep(1)
# self.mem_test()
# time.sleep(1)
del loss
del logits
del data_batch
gc.collect()