def remove_rare_variants(threshold=0.01):
    """Return a copy of the tree sequence with rare variants removed.

    Reads the module-level names ``ts`` (the tree sequence to filter) and
    ``N`` (the divisor that turns a sample count into a frequency --
    presumably the total number of samples; confirm where it is defined).

    A site is kept only when the fraction ``num_samples / N`` of samples
    below the node of its first mutation is strictly greater than
    ``threshold``. Nodes and edges are copied over unchanged; only the site
    and mutation tables are rebuilt.

    Args:
        threshold: Frequency cut-off. Sites whose derived-allele frequency
            is less than or equal to this value are dropped. Defaults to
            0.01, the value that used to be hard-coded.

    Returns:
        The new tree sequence produced by ``msprime.load_tables``.
    """
    sites = msprime.SiteTable()
    mutations = msprime.MutationTable()
    for tree in ts.trees():
        for site in tree.sites():
            # A site without any mutation carries no derived allele, so
            # there is nothing to keep (and indexing [0] would fail).
            if not site.mutations:
                continue
            # Only the first mutation at a site is inspected and copied;
            # any further mutations at the same site are not carried over.
            mut = site.mutations[0]
            freq = tree.num_samples(mut.node) / N
            if freq > threshold:
                # Row ids in the new site table differ from the original
                # ones, so the mutation must point at the id just returned.
                site_id = sites.add_row(
                    position=site.position,
                    ancestral_state=site.ancestral_state)
                mutations.add_row(
                    site=site_id,
                    node=mut.node,
                    derived_state=mut.derived_state)
    # Reuse the original topology; only sites and mutations were filtered.
    tables = ts.dump_tables()
    return msprime.load_tables(
        nodes=tables.nodes, edges=tables.edges, sites=sites, mutations=mutations)