Source code for lfd.action_selection

from __future__ import division

import numpy as np

from lfd.environment import simulation_object
from lfd.mmqe.search import beam_search
from lfd.rapprentice.knot_classifier import isKnot as is_knot


class ActionSelection(object):
    """Base class for strategies that pick which demonstration to apply.

    Subclasses override plan_agenda; this class only stores the
    registration factory handed to the constructor.
    """

    def __init__(self, registration_factory):
        """Inits ActionSelection

        Args:
            registration_factory: RegistrationFactory
        """
        self.registration_factory = registration_factory

    def plan_agenda(self, scene_state, timestep=None):
        """Plans an agenda of demonstrations for the given scene_state

        Args:
            scene_state: SceneState of the scene for which the agenda is
                returned
            timestep: optional step index. The overrides in this module take
                it as a second positional argument, so the base signature
                accepts it too; callers that pass only scene_state still
                work.

        Returns:
            An agenda, which is a list of demonstration names, and a list of
            the values of the respective demonstrations

        Raises:
            NotImplementedError: always; subclasses must override this
                method.
        """
        raise NotImplementedError
class GreedyActionSelection(ActionSelection):
    """Ranks every action by the value reported by the registration factory."""

    def plan_agenda(self, scene_state, timestep):
        """Plans an agenda ordered by increasing q-value

        Args:
            scene_state: SceneState passed to registration_factory.batch_cost
            timestep: unused by this strategy; accepted so that the signature
                matches the other action selections in this module

        Returns:
            A pair ((agenda, q_values), False). agenda is a tuple of actions
            and q_values the tuple of their values, both sorted by ascending
            q-value (ties are broken by comparing the actions). The trailing
            False signals that no goal was found. When batch_cost reports no
            actions, both tuples are empty.
        """
        action2q_value = self.registration_factory.batch_cost(scene_state)
        ranked = sorted((q_value, action)
                        for (action, q_value) in action2q_value.items())
        if not ranked:
            # zip(*[]) yields nothing, so unpacking it into two names would
            # raise ValueError; report an empty agenda instead.
            return ((), ()), False
        q_values, agenda = zip(*ranked)
        # Return false for goal not found
        return (agenda, q_values), False
class FeatureActionSelection(ActionSelection):
    """Chooses actions with beam_search over simulated transfers.

    Candidate states are scored with a linear function of the supplied
    features (features . weights + w0), and each candidate action is rolled
    out in the simulated environment.
    """

    def __init__(self, registration_factory, features, actions, demos,
                 width, depth, simulator=None, lfd_env=None):
        """Inits FeatureActionSelection

        Args:
            registration_factory: RegistrationFactory, forwarded to
                ActionSelection
            features: object providing features(state, timestep=...),
                weights, w0 and src_ctx.seg_names, all read by plan_agenda
            actions: mapping; its keys are stored as self.actions
            demos: container indexed by action; the selected entry is handed
                to simulator.transfer
            width: forwarded to beam_search as width
            depth: forwarded to beam_search as depth
            simulator: object with transfer(demo, state, plotting=False);
                stored as self.transferer. plan_agenda dereferences it, so
                it must be supplied before planning.
            lfd_env: environment used to execute trajectories and observe
                the scene. plan_agenda dereferences it, so it must be
                supplied before planning.
        """
        self.features = features
        self.actions = actions.keys()
        # self.features.set_name2ind(self.actions)
        self.demos = demos
        self.width = width
        self.depth = depth
        self.transferer = simulator
        self.lfd_env = lfd_env
        super(FeatureActionSelection, self).__init__(registration_factory)

    def plan_agenda(self, scene_state, timestep):
        """Plans an agenda by running beam_search from scene_state

        Args:
            scene_state: state the search starts from
            timestep: timestep forwarded to beam_search

        Returns:
            Whatever beam_search returns for the given width and depth.

        Raises:
            RuntimeError: from the transfer callback, if the simulation
                holds no RopeSimulationObject to run the knot check on.
        """
        def evaluator(state, ts):
            # Linear score of the state's features. Any failure while
            # computing it makes every entry -inf so that the state is never
            # preferred.
            try:
                score = np.dot(self.features.features(state, timestep=ts),
                               self.features.weights) + self.features.w0
            except Exception:
                # Exception rather than a bare except, so KeyboardInterrupt
                # and SystemExit still propagate out of a long search.
                # NOTE(review): the fallback has len(weights) entries, which
                # may differ from the shape of a regular score -- confirm
                # against beam_search.
                return -np.inf * np.ones(len(self.features.weights))
            return score

        def simulate_transfer(state, action, next_state_id):
            aug_traj = self.transferer.transfer(self.demos[action], state,
                                                plotting=False)
            self.lfd_env.execute_augmented_trajectory(aug_traj, step_viewer=0)
            result_state = self.lfd_env.observe_scene()

            # Get the rope simulation object and determine if it's a knot
            for sim_obj in self.lfd_env.sim.sim_objs:
                if isinstance(sim_obj, simulation_object.RopeSimulationObject):
                    rope_sim_obj = sim_obj
                    break
            else:
                # Without this branch the lookup below would fail with an
                # opaque UnboundLocalError.
                raise RuntimeError(
                    "no RopeSimulationObject found in the simulation")
            rope_knot = is_knot(rope_sim_obj.rope.GetControlPoints())
            return (result_state, next_state_id, rope_knot)

        return beam_search(scene_state, timestep,
                           self.features.src_ctx.seg_names,
                           simulate_transfer, evaluator, self.lfd_env.sim,
                           width=self.width, depth=self.depth)